import os
import jax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt
from skimage.io import imread, imsave

# A horizontal flip

img = imread('./The_Cat.jpg')
plt.figure(figsize=(6, 8))
plt.imshow(img)

seed = 42
key = jr.key(seed)
# One standard-normal sample for every element of the image array.
std_noise = jr.normal(key, img.shape)
std_noise.min(), std_noise.max(), std_noise.mean(), std_noise.std()
# Out: (Array(-5.012798, dtype=float32),
#       Array(4.638076, dtype=float32),
#       Array(5.15859e-05, dtype=float32),
#       Array(1.0000969, dtype=float32))

# Shift/scale to mean 0.5 and std 0.1. This puts nearly all values inside
# imshow's [0, 1] float range (see the stats below).
noise = 0.5 + 0.1 * std_noise
noise.min(), noise.max(), noise.mean(), noise.std()
# Out: (Array(-0.00127977, dtype=float32),
#       Array(0.9638076, dtype=float32),
#       Array(0.5000051, dtype=float32),
#       Array(0.1000097, dtype=float32))

plt.imshow(noise)
# matplotlib warns because a few values fall just below 0 and are clipped:
# Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0012797713..0.9638076].

# Scale pixel values (assumed to be 0..255 integers, hence the 255.0 divisor)
# into [0, 1] floats so they can be combined with the float noise.
image = img / 255.0
new_image = image + noise
new_image.min(), new_image.max(), new_image.mean(), new_image.std()
# Out: (Array(0.11410853, dtype=float32),
#       Array(1.8716927, dtype=float32),
#       Array(0.95421576, dtype=float32),
#       Array(0.26887897, dtype=float32))

# The noise has mean 0.5. The sum is therefore brightened well past 1.0,
# and imshow clips it for display:
plt.imshow(new_image)
# Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.11410853..1.8716927].

# Mirror the image left-to-right (reverse axis 1) in two equivalent ways:
# first with plain reversed slicing, then with the dedicated jnp.fliplr helper.
# Each result is displayed; afterwards `image_flipped` holds the fliplr version.
for image_flipped in (image[:, ::-1, :], jnp.fliplr(image)):
    plt.imshow(image_flipped)
# Performing random augmentations
key
# Out: Array((), dtype=key<fry>) overlaying:
#      [ 0 42]

# Derive two new keys from `key` (the original key value is left unchanged).
key1, key2 = jr.split(key)
key1, key2
# Out: (Array((), dtype=key<fry>) overlaying:
#       [1832780943 270669613],
#       Array((), dtype=key<fry>) overlaying:
#       [ 64467757 2916123636])
def add_noise(image, rng_key, mean=0.5, std=0.1):
    """Add Gaussian noise to ``image`` and min-max rescale the result to [0, 1].

    Args:
        image: array of pixel values; noise of the same shape is drawn.
        rng_key: JAX PRNG key used to draw the noise.
        mean: mean of the added noise (default 0.5, the previously hard-coded
            value). Because the result is min-max rescaled, a constant shift
            does not change the output beyond float rounding.
        std: standard deviation of the added noise (default 0.1).

    Returns:
        ``image + noise`` linearly rescaled so its global minimum becomes 0
        and its global maximum becomes 1. If the noisy image is constant
        (zero range), an all-zeros array is returned instead of NaNs.
    """
    noise = mean + std * jr.normal(rng_key, image.shape)
    noisy = image + noise
    lo = noisy.min()
    span = noisy.max() - lo
    # Avoid dividing by zero for a constant input. jnp.where (rather than a
    # Python `if`) keeps this traceable when called inside jax.lax.switch.
    safe_span = jnp.where(span > 0, span, 1.0)
    return (noisy - lo) / safe_span


# Candidate augmentations. Each one takes (image, key) so jax.lax.switch can
# call them uniformly; the flip simply ignores its key.
augmentations = [
    add_noise,
    lambda x, _key: jnp.fliplr(x),
]


def random_augmentation(image, augmentations, rng_key):
    """Apply one augmentation picked at random from ``augmentations``.

    Args:
        image: input image array.
        augmentations: sequence of callables ``f(image, key)``. Under
            ``jax.lax.switch``, all of them must return arrays of the same
            shape and dtype.
        rng_key: JAX PRNG key. It is split into one key for choosing the
            augmentation and one key passed to the chosen augmentation.

    Returns:
        The output of the selected augmentation.
    """
    choice_key, aug_key = jr.split(rng_key)
    # randint's maxval is exclusive, so the index lies in [0, len(augmentations)).
    augmentation_idx = jr.randint(key=choice_key, minval=0,
                                  maxval=len(augmentations), shape=())
    return jax.lax.switch(augmentation_idx, augmentations, image, aug_key)


key = jr.key(4242)
image_aug = random_augmentation(image, augmentations, key)
plt.imshow(image_aug)
# NumPy PRNG
import numpy as np
from numpy import random
from numpy.random import default_rng

# A NumPy Generator keeps internal state that advances with every draw, so two
# identical calls return different numbers.
rng = default_rng()
vals = rng.normal(loc=0.5, scale=0.1, size=(3, 5))
more_vals = rng.normal(loc=0.5, scale=0.1, size=(3, 5))
vals, more_vals
# Out (default_rng() is unseeded here, so the exact numbers differ per run):
# (array([[0.45879781, 0.43006299, 0.51154182, 0.50356447, 0.36374889],
#         [0.47998945, 0.72376906, 0.52692486, 0.63513868, 0.48407395],
#         [0.34634252, 0.44499043, 0.50929666, 0.52215354, 0.64615907]]),
#  array([[0.41411929, 0.61371454, 0.53851603, 0.52663525, 0.57510088],
#         [0.32308464, 0.48213926, 0.68905328, 0.40801101, 0.58335347],
#         [0.44979266, 0.4755453 , 0.53730226, 0.56361181, 0.57094472]]))
# Controlling the seed
# Re-seeding NumPy's legacy global generator before each draw makes the
# draws repeat exactly.
random.seed(42)
vals = random.normal(loc=0.5, scale=0.1, size=(3, 5))
random.seed(42)
more_vals = random.normal(loc=0.5, scale=0.1, size=(3, 5))
# Drawing the same 15 values one at a time yields the same matrix as one
# vectorized call.
random.seed(42)
even_more_vals = np.array(
    [random.normal(loc=0.5, scale=0.1) for _ in range(3 * 5)]
).reshape((3, 5))

vals
# Out: array([[0.54967142, 0.48617357, 0.56476885, 0.65230299, 0.47658466],
#             [0.4765863 , 0.65792128, 0.57674347, 0.45305256, 0.554256  ],
#             [0.45365823, 0.45342702, 0.52419623, 0.30867198, 0.32750822]])
more_vals
# Out: array([[0.54967142, 0.48617357, 0.56476885, 0.65230299, 0.47658466],
#             [0.4765863 , 0.65792128, 0.57674347, 0.45305256, 0.554256  ],
#             [0.45365823, 0.45342702, 0.52419623, 0.30867198, 0.32750822]])
even_more_vals
# Out: array([[0.54967142, 0.48617357, 0.56476885, 0.65230299, 0.47658466],
#             [0.4765863 , 0.65792128, 0.57674347, 0.45305256, 0.554256  ],
#             [0.45365823, 0.45342702, 0.52419623, 0.30867198, 0.32750822]])
# Looking at the generator state:
random.seed(42)
# The legacy generator state is a tuple:
# (algorithm name, 624-word key array, position, has_gauss flag, cached Gaussian).
state = random.get_state()
type(state), state[0]
# Out: (tuple, 'MT19937')
state[1].shape
# Out: (624,)
state[2], state[3], state[4]
# Out: (624, 0, 0.0)
state('MT19937',
array([ 42, 3107752595, 1895908407, 3900362577, 3030691166,
4081230161, 2732361568, 1361238961, 3961642104, 867618704,
2837705690, 3281374275, 3928479052, 3691474744, 3088217429,
1769265762, 3769508895, 2731227933, 2930436685, 486258750,
1452990090, 3321835500, 3520974945, 2343938241, 928051207,
2811458012, 3391994544, 3688461242, 1372039449, 3706424981,
1717012300, 1728812672, 1688496645, 1203107765, 1648758310,
440890502, 1396092674, 626042708, 3853121610, 669844980,
2992565612, 310741647, 3820958101, 3474052697, 305511342,
2053450195, 705225224, 3836704087, 3293527636, 1140926340,
2738734251, 574359520, 1493564308, 269614846, 427919468,
2903547603, 2957214125, 181522756, 4137743374, 2557886044,
3399018834, 1348953650, 1575066973, 3837612427, 705360616,
4138204617, 1604205300, 1605197804, 590851525, 2371419134,
2530821810, 4183626679, 2872056396, 3895467791, 1156426758,
184917518, 2502875602, 2730245981, 3251099593, 2228829441,
2591075711, 3048691618, 3030004338, 1726207619, 993866654,
823585707, 936803789, 3180156728, 1191670842, 348221088,
988038522, 3281236861, 1153842962, 4152167900, 98291801,
816305276, 575746380, 1719541597, 2584648622, 1791391551,
3234806234, 413529090, 219961136, 4180088407, 1135264652,
3923811338, 2304598263, 762142228, 1980420688, 1225347938,
3657621885, 3762382117, 1157119598, 2556627260, 2276905960,
3857700293, 1903185298, 4258743924, 2078637161, 4160077183,
3569294948, 2138906140, 1346725611, 1473959117, 2798330104,
3785346335, 4103334026, 3448442764, 1142532843, 4278036691,
3071994514, 3474299731, 1121195796, 1536841934, 2132070705,
1064908919, 2840327803, 992870214, 2041326888, 2906112696,
4182466030, 1031463950, 703166484, 854266995, 4157971695,
4071962029, 2600094776, 2770410869, 3776335751, 2599879593,
2451043853, 2223709058, 2098813464, 4008111478, 2959232195,
3072496064, 2498909222, 4020139729, 785990520, 958060279,
4183949075, 2392404465, 533774465, 4092066952, 3967420027,
1726137853, 2907699474, 3158758391, 1460845905, 1323598137,
2446717890, 3004885867, 3447263769, 1378488047, 3172418196,
652839901, 1695052769, 226007057, 778836071, 1216725078,
655651335, 1850195064, 427367795, 800074262, 2241880422,
1713434925, 339981078, 1730571881, 672610244, 1952245009,
2729177102, 3516932475, 4032720152, 3177283432, 411893652,
2440235559, 3587427933, 43170267, 39225133, 3904203400,
1935961247, 3843123487, 1625453782, 1337993374, 2095455879,
3402219947, 634671126, 70868861, 3072823841, 851862432,
1828056818, 2794213810, 1222863684, 2164539406, 4249334162,
1380362252, 1512719097, 2773165233, 4063118969, 3041859837,
529421431, 563872464, 2478730478, 3168749051, 4132953373,
3922807735, 1124217574, 1970058502, 1744120743, 1906315107,
1074758800, 1611130652, 2878846041, 886823888, 1175456250,
1669874674, 2428820171, 1044308794, 3841962192, 138850094,
1239727126, 1753711876, 2194286827, 872797664, 4276240980,
690338888, 4087206238, 2279169960, 1117436170, 3344885072,
3127829945, 315537090, 3802787206, 4157203318, 1637047079,
3774106877, 3230158646, 1855823338, 1931415993, 667252379,
4288528171, 1587598285, 1096793218, 1916566454, 101891899,
2354644560, 3351208292, 1467125166, 2177732119, 4122299478,
3904084887, 2653591155, 4201043109, 2867379343, 2660555187,
3641744616, 4126452939, 326579197, 2697259239, 3365236848,
3007834487, 4118919490, 3306741951, 2285455175, 1956645973,
1879691841, 891565150, 1843460149, 2013381028, 819311674,
123282948, 1436558519, 1154343666, 206804484, 1650349242,
2142011886, 304163699, 2608574600, 2500624796, 2996744833,
2344192475, 3152512202, 165571606, 691170269, 1806226529,
568535825, 1243813863, 3068953841, 3843784723, 1540495237,
4246006858, 1303595780, 3288680241, 864868851, 819595545,
3230857496, 3574119395, 1545404573, 2970139338, 4292786727,
1803072884, 1374565738, 1736333177, 1978645403, 3962597126,
1068006206, 3458125500, 168085922, 1597587506, 2052497512,
1323596727, 2421372441, 1468386547, 3574947527, 3363915938,
860279252, 1309097460, 3065417722, 1490716202, 3476091722,
1669402145, 895071221, 1432690175, 3353592973, 149850974,
2789493615, 826939483, 666980418, 755367270, 3988951195,
21783894, 1924727373, 1699517788, 1152431122, 2593798113,
3522529522, 2797535609, 4018366956, 2350035889, 3010507270,
2832621820, 627979167, 997422629, 365587204, 2302500352,
1720920631, 689999548, 3713985947, 3267499624, 1971264680,
1981530399, 1662926921, 1833821660, 1422522022, 3141447769,
2727954526, 4172728772, 1787436028, 1902276939, 3145551277,
4207627911, 2497093521, 4111966589, 3929089589, 2253454030,
1069424637, 2165048659, 2848813944, 2435898022, 2546206777,
3864777677, 3107311565, 3776562483, 1040285049, 3171631943,
2404677828, 2522848682, 2930777301, 2831905121, 1436989598,
602730315, 664177960, 3959954010, 3116042160, 2881899726,
233404945, 4058465099, 1781994751, 485046222, 2776777695,
432082123, 1989128370, 86344507, 2510576356, 2194076764,
1742125237, 3715839140, 895100548, 147445686, 705462897,
2245325113, 1052295404, 1956014786, 2916055958, 1829369612,
2541711050, 1594343058, 3708804266, 150438233, 323857098,
294681952, 783931535, 606075163, 2427042904, 121207604,
3943199031, 1196785464, 1818211378, 1788241109, 3138862427,
2037307093, 2306750301, 1644605749, 165986111, 542190743,
486828112, 1757411662, 894543082, 4108143634, 1232805238,
3801632949, 3863166865, 713767006, 2091486427, 3174776264,
1157004409, 623072544, 1667151721, 3361539538, 696723008,
3247069452, 682044344, 1382136166, 1385645682, 4219951151,
2747881261, 2489355869, 786564174, 2040230554, 2967874556,
1414286092, 2677969656, 1393412218, 2216095072, 935533444,
3662643439, 3285199608, 3103672804, 522796956, 3952383595,
1928659176, 3397717710, 4278554051, 1984736931, 3559102926,
1878353094, 875578217, 2398931796, 2313634006, 1606027661,
2790634022, 2334166559, 1857067101, 666458681, 1626872683,
2155121857, 715449823, 1865157100, 2938814835, 4084911240,
45488075, 3474982924, 1750873825, 2246019159, 125388929,
1110287838, 652200437, 4212247716, 2702974687, 2963764270,
208692058, 3170393729, 1378248367, 752591527, 591629541,
2253399388, 2402291226, 3089656189, 3202324513, 3818308310,
2828131601, 2690672008, 3676629884, 1007739430, 4072247562,
3574795162, 518485611, 1889402182, 3687902739, 3410263649,
2790674620, 779455241, 3573984673, 3053204735, 4089925351,
789980683, 476440431, 3843536868, 2400661309, 3139919094,
1643266656, 113318754, 428163528, 2386492935, 3807242009,
574560611, 3174039857, 3774465602, 1164640969, 455942925,
1374407495, 2562304709, 1024844203, 521375136, 417432138,
1203241821, 2900988280, 2841030991, 2301700751, 369508560,
2396447808, 1891459643, 4225682708, 3930667846, 1518293357,
2697063889, 3113075061, 2411136298, 2836361984, 4105335811,
914081338, 2675982621, 1816939127, 1596754123, 1464603632,
1598478676, 1318403529, 4016663081, 2106416852, 2757323084,
2042842122, 1175184796, 2212339255, 1334626864, 3994484893,
3938045599, 2166620630, 3036360431, 397499085, 975931950,
1868702836, 3530424696, 3532548823, 2770836469, 3537418693,
3344319345, 3208552526, 1771170897, 4097379814, 3761572528,
2794194423, 706836738, 2953105956, 3446096217, 220984542,
309619699, 223913021, 3985142640, 1757616575, 2582763607,
4018329835, 1393278443, 4121569718, 2087146446, 4282833425,
807775617, 1396604749, 3571181413, 90301352, 2618014643,
2783561793, 1329389532, 836540831, 26719530], dtype=uint32),
624,
0,
0.0)
# JAX PRNG
key = jr.key(42)
type(key)
# Out: jax._src.prng.PRNGKeyArray
key
# Out: Array((), dtype=key<fry>) overlaying:
#      [ 0 42]

vals = jr.normal(key, shape=(3, 5))
vals
# Out: Array([[-0.02830462,  0.46713185,  0.29570296,  0.15354592, -0.12403282],
#             [ 0.21692315, -1.4408789 ,  0.7558599 ,  0.52140963,  0.9101704 ],
#             [-0.3844966 ,  1.1398233 ,  1.4457862 ,  1.0809066 , -0.05629321]], dtype=float32)

# Unlike NumPy's stateful generator, a JAX key is not advanced by use:
# calling with the same key again returns exactly the same numbers.
more_vals = jr.normal(key, shape=(3, 5))
more_vals
# Out: Array([[-0.02830462,  0.46713185,  0.29570296,  0.15354592, -0.12403282],
#             [ 0.21692315, -1.4408789 ,  0.7558599 ,  0.52140963,  0.9101704 ],
#             [-0.3844966 ,  1.1398233 ,  1.4457862 ,  1.0809066 , -0.05629321]], dtype=float32)
# Splitting the key to produce multiple keys
key = jr.key(42)
key1, key2 = jr.split(key, num=2)
# Splitting does not modify `key` itself:
key
# Out: Array((), dtype=key<fry>) overlaying:
#      [ 0 42]

vals = jr.normal(key1, shape=(3, 5))
more_vals = jr.normal(key2, shape=(3, 5))
key1, key2
# Out: (Array((), dtype=key<fry>) overlaying:
#       [1832780943 270669613],
#       Array((), dtype=key<fry>) overlaying:
#       [ 64467757 2916123636])

# Different keys give different samples.
vals, more_vals
# Out: (Array([[ 0.07592554, -0.48634264,  1.2903206 ,  0.5196119 ,  0.30040437],
#              [ 0.31034866,  0.5761609 , -0.8074621 , -1.9883217 ,  0.6395295 ],
#              [ 0.21763174,  0.00247425,  1.6645706 ,  0.20313536, -0.02138225]], dtype=float32),
#       Array([[ 0.60576403,  0.7990441 , -0.908927  , -0.63525754, -1.2226585 ],
#              [-0.83226097, -0.47417238, -1.2504351 , -0.17678244, -0.04917514],
#              [-0.41177532, -0.39363015,  1.3116323 ,  0.21555556,  0.41164538]], dtype=float32))
# Generating many keys at once
key = jr.key(42)
# Split into 100 keys: keep the first as the new `key`, collect the rest.
key, *subkeys = jr.split(key, num=100)
key
# Out: Array((), dtype=key<fry>) overlaying:
#      [1832780943 270669613]
len(subkeys)
# Out: 99
type(subkeys)
# Out: list
# Generating new keys using fold_in()
key = jr.key(42)
for i in range(5):
    # Derive a distinct key from the base key and the loop index.
    new_key = jr.fold_in(key, i)
    print(new_key)
    vals = jr.normal(new_key, shape=(3, 5))
    # do something with values

# Printed:
# Array((), dtype=key<fry>) overlaying:
# [1832780943 270669613]
# Array((), dtype=key<fry>) overlaying:
# [ 64467757 2916123636]
# Array((), dtype=key<fry>) overlaying:
# [2465931498 255383827]
# Array((), dtype=key<fry>) overlaying:
# [3134548294 894150801]
# Array((), dtype=key<fry>) overlaying:
# [2954079971 3276725750]
# Using a string to fold into a new key
import hashlib
def my_hash(s):
    """Map a string to a deterministic unsigned 32-bit integer.

    The string is encoded with the default (UTF-8) codec and hashed with
    SHA-1. The first 4 digest bytes are read as a big-endian integer, which
    is equal to parsing the first 8 hex digits of the digest. The result
    therefore lies in [0, 2**32).
    """
    digest = hashlib.sha1(s.encode()).digest()
    leading_bytes = digest[:4]
    return int.from_bytes(leading_bytes, byteorder='big')
some_string = 'layer7_2'
# my_hash turns the string into an unsigned 32-bit integer usable by fold_in.
some_int = my_hash(some_string)
some_int
# Out: 2649017889

key = jr.key(42)
# Derive a key tied to this particular string.
new_key = jr.fold_in(key, some_int)
new_key
# Out: Array((), dtype=key<fry>) overlaying:
#      [3110527424 3716265121]