
Commit 75a53f4

update notebooks to new-style typed PRNG keys

Committed Mar 7, 2024 · 1 parent 0bdbe76

18 files changed: +48 -48 lines
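For readers skimming the diff: old-style `jax.random.PRNGKey` returns a raw array of two `uint32` words, while the new-style `jax.random.key` that this commit adopts returns a scalar array with an opaque, typed key dtype. A minimal sketch of the difference (not part of the commit), assuming JAX 0.4.16 or newer where `jax.random.key` is available:

```python
import jax

legacy = jax.random.PRNGKey(0)     # old style: raw array of two uint32 words
print(legacy.shape, legacy.dtype)  # (2,) uint32

key = jax.random.key(0)            # new style: scalar array with a typed key dtype
print(key.shape, key.dtype)        # () key<fry>  (for the default threefry impl)

# Consumers of keys are unchanged; both styles work with the random API:
x = jax.random.normal(key, (3,))
```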
 

docs/notebooks/Common_Gotchas_in_JAX.ipynb (+4 -4)

@@ -1006,7 +1006,7 @@
 "source": [
 "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
 "\n",
-"The random state is described by two unsigned-int32s that we call a __key__:"
+"The random state is described by a special array element that we call a __key__:"
 ]
 },
 {
@@ -1030,7 +1030,7 @@
 ],
 "source": [
 "from jax import random\n",
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "key"
 ]
 },
@@ -2121,7 +2121,7 @@
 }
 ],
 "source": [
-"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
+"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
 "x.dtype"
 ]
 },
@@ -2188,7 +2188,7 @@
 "source": [
 "import jax.numpy as jnp\n",
 "from jax import random\n",
-"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
+"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
 "x.dtype # --> dtype('float64')"
 ]
 },

docs/notebooks/Common_Gotchas_in_JAX.md (+4 -4)

@@ -463,14 +463,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
 
 JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
 
-The random state is described by two unsigned-int32s that we call a __key__:
+The random state is described by a special array element that we call a __key__:
 
 ```{code-cell} ipython3
 :id: yPHE7KTWgAWs
 :outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
 
 from jax import random
-key = random.PRNGKey(0)
+key = random.key(0)
 key
 ```
 
@@ -1071,7 +1071,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the
 :id: CNNGtzM3NDkO
 :outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8
 
-x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
+x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
 x.dtype
 ```
 
@@ -1117,7 +1117,7 @@ We can then confirm that `x64` mode is enabled:
 
 import jax.numpy as jnp
 from jax import random
-x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
+x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
 x.dtype # --> dtype('float64')
 ```
 
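A side note on the notebook text changed above: because the PRNG is splittable, fresh randomness comes from forking a key rather than mutating hidden state. A short sketch of the usual pattern, which is the same for old- and new-style keys:

```python
from jax import random

key = random.key(0)
key, subkey = random.split(key)        # fork: keep `key` for later, spend `subkey` now
sample = random.normal(subkey, (2,))   # reusing the same subkey reproduces the sample
```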

docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb (+5 -5)

@@ -131,7 +131,7 @@
 ],
 "source": [
 "# Create an array of random values:\n",
-"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
+"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
 "# and use jax.device_put to distribute it across devices:\n",
 "y = jax.device_put(x, sharding.reshape(4, 2))\n",
 "jax.debug.visualize_array_sharding(y)"
@@ -272,7 +272,7 @@
 "outputs": [],
 "source": [
 "import jax\n",
-"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))"
+"x = jax.random.normal(jax.random.key(0), (8192, 8192))"
 ]
 },
 {
@@ -1513,7 +1513,7 @@
 },
 "outputs": [],
 "source": [
-"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
+"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
 "x = jax.device_put(x, sharding.reshape(4, 2))"
 ]
 },
@@ -1738,7 +1738,7 @@
 "layer_sizes = [784, 8192, 8192, 8192, 10]\n",
 "batch_size = 8192\n",
 "\n",
-"params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)"
+"params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)"
 ]
 },
 {
@@ -2184,7 +2184,7 @@
 " numbers = jax.random.uniform(key, x.shape)\n",
 " return x + numbers\n",
 "\n",
-"key = jax.random.PRNGKey(42)\n",
+"key = jax.random.key(42)\n",
 "x_sharding = jax.sharding.PositionalSharding(jax.devices())\n",
 "x = jax.device_put(jnp.arange(24), x_sharding)"
 ]

docs/notebooks/Distributed_arrays_and_automatic_parallelization.md (+5 -5)

@@ -81,7 +81,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
 :outputId: 3b518df8-5c29-4848-acc3-e41df939f30b
 
 # Create an array of random values:
-x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
+x = jax.random.normal(jax.random.key(0), (8192, 8192))
 # and use jax.device_put to distribute it across devices:
 y = jax.device_put(x, sharding.reshape(4, 2))
 jax.debug.visualize_array_sharding(y)
@@ -144,7 +144,7 @@ For example, here's a value with a single-device `Sharding`:
 :id: VmoX4SUp3vGJ
 
 import jax
-x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
+x = jax.random.normal(jax.random.key(0), (8192, 8192))
 ```
 
 ```{code-cell}
@@ -609,7 +609,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
 ```{code-cell}
 :id: Q1wuDp-L3vGT
 
-x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
+x = jax.random.normal(jax.random.key(0), (8192, 8192))
 x = jax.device_put(x, sharding.reshape(4, 2))
 ```
 
@@ -720,7 +720,7 @@ def init_model(key, layer_sizes, batch_size):
 layer_sizes = [784, 8192, 8192, 8192, 10]
 batch_size = 8192
 
-params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)
+params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
 ```
 
 +++ {"id": "sJv_h0AS2drh"}
@@ -902,7 +902,7 @@ def f(key, x):
   numbers = jax.random.uniform(key, x.shape)
   return x + numbers
 
-key = jax.random.PRNGKey(42)
+key = jax.random.key(42)
 x_sharding = jax.sharding.PositionalSharding(jax.devices())
 x = jax.device_put(jnp.arange(24), x_sharding)
 ```

docs/notebooks/Neural_Network_and_Data_Loading.ipynb (+3 -3)

@@ -84,7 +84,7 @@
 "num_epochs = 8\n",
 "batch_size = 128\n",
 "n_targets = 10\n",
-"params = init_network_params(layer_sizes, random.PRNGKey(0))"
+"params = init_network_params(layer_sizes, random.key(0))"
 ]
 },
 {
@@ -150,7 +150,7 @@
 ],
 "source": [
 "# This works on single examples\n",
-"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
+"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
 "preds = predict(params, random_flattened_image)\n",
 "print(preds.shape)"
 ]
@@ -173,7 +173,7 @@
 ],
 "source": [
 "# Doesn't work with a batch\n",
-"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
+"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
 "try:\n",
 " preds = predict(params, random_flattened_images)\n",
 "except TypeError:\n",

docs/notebooks/Neural_Network_and_Data_Loading.md (+3 -3)

@@ -71,7 +71,7 @@ step_size = 0.01
 num_epochs = 8
 batch_size = 128
 n_targets = 10
-params = init_network_params(layer_sizes, random.PRNGKey(0))
+params = init_network_params(layer_sizes, random.key(0))
 ```
 
 +++ {"id": "BtoNk_yxWtIw"}
@@ -109,7 +109,7 @@ Let's check that our prediction function only works on single images.
 :outputId: 9d3b29e8-fab3-4ecb-9f63-bc8c092f9006
 
 # This works on single examples
-random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
+random_flattened_image = random.normal(random.key(1), (28 * 28,))
 preds = predict(params, random_flattened_image)
 print(preds.shape)
 ```
@@ -119,7 +119,7 @@ print(preds.shape)
 :outputId: d5d20211-b6da-44e9-f71e-946f2a9d0fc4
 
 # Doesn't work with a batch
-random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
+random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
 try:
   preds = predict(params, random_flattened_images)
 except TypeError:

docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb (+1 -1)

@@ -66,7 +66,7 @@
 },
 "outputs": [],
 "source": [
-"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
+"x = random.normal(random.key(0), (5000, 5000))\n",
 "def f(w, b, x):\n",
 " return jnp.tanh(jnp.dot(x, w) + b)\n",
 "fast_f = jit(f)"

docs/notebooks/Writing_custom_interpreters_in_Jax.md (+1 -1)

@@ -48,7 +48,7 @@ JAX provides a NumPy-like API for numerical computing which can be used as is, b
 ```{code-cell} ipython3
 :id: HmlMcICOcSXR
 
-x = random.normal(random.PRNGKey(0), (5000, 5000))
+x = random.normal(random.key(0), (5000, 5000))
 def f(w, b, x):
   return jnp.tanh(jnp.dot(x, w) + b)
 fast_f = jit(f)

docs/notebooks/autodiff_cookbook.ipynb (+4 -4)

@@ -27,7 +27,7 @@
 "from jax import grad, jit, vmap\n",
 "from jax import random\n",
 "\n",
-"key = random.PRNGKey(0)"
+"key = random.key(0)"
 ]
 },
 {
@@ -1055,7 +1055,7 @@
 " outs, = vmap(vjp_fun)(M)\n",
 " return outs\n",
 "\n",
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "num_covecs = 128\n",
 "U = random.normal(key, (num_covecs,) + y.shape)\n",
 "\n",
@@ -1306,7 +1306,7 @@
 "outputs": [],
 "source": [
 "def check(seed):\n",
-" key = random.PRNGKey(seed)\n",
+" key = random.key(seed)\n",
 "\n",
 " # random coeffs for u and v\n",
 " key, subkey = random.split(key)\n",
@@ -1399,7 +1399,7 @@
 "outputs": [],
 "source": [
 "def check(seed):\n",
-" key = random.PRNGKey(seed)\n",
+" key = random.key(seed)\n",
 "\n",
 " # random coeffs for u and v\n",
 " key, subkey = random.split(key)\n",

docs/notebooks/autodiff_cookbook.md (+4 -4)

@@ -29,7 +29,7 @@ import jax.numpy as jnp
 from jax import grad, jit, vmap
 from jax import random
 
-key = random.PRNGKey(0)
+key = random.key(0)
 ```
 
 +++ {"id": "YxnjtAGN6vu2"}
@@ -614,7 +614,7 @@ def vmap_mjp(f, x, M):
   outs, = vmap(vjp_fun)(M)
   return outs
 
-key = random.PRNGKey(0)
+key = random.key(0)
 num_covecs = 128
 U = random.normal(key, (num_covecs,) + y.shape)
 
@@ -770,7 +770,7 @@ Here's a check:
 :id: BGZV__zupIMS
 
 def check(seed):
-  key = random.PRNGKey(seed)
+  key = random.key(seed)
 
   # random coeffs for u and v
   key, subkey = random.split(key)
@@ -833,7 +833,7 @@ Here's a check of the VJP rules:
 :id: 4J7edvIBttcU
 
 def check(seed):
-  key = random.PRNGKey(seed)
+  key = random.key(seed)
 
   # random coeffs for u and v
   key, subkey = random.split(key)

docs/notebooks/convolutions.ipynb (+2 -2)

@@ -60,7 +60,7 @@
 "import jax.numpy as jnp\n",
 "import numpy as np\n",
 "\n",
-"key = random.PRNGKey(1701)\n",
+"key = random.key(1701)\n",
 "\n",
 "x = jnp.linspace(0, 10, 500)\n",
 "y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))\n",
@@ -130,7 +130,7 @@
 "ax[0].set_title('original')\n",
 "\n",
 "# Create a noisy version by adding random Gaussian noise\n",
-"key = random.PRNGKey(1701)\n",
+"key = random.key(1701)\n",
 "noisy_image = image + 50 * random.normal(key, image.shape)\n",
 "ax[1].imshow(noisy_image, cmap='binary_r')\n",
 "ax[1].set_title('noisy')\n",

docs/notebooks/convolutions.md (+2 -2)

@@ -43,7 +43,7 @@ from jax import random
 import jax.numpy as jnp
 import numpy as np
 
-key = random.PRNGKey(1701)
+key = random.key(1701)
 
 x = jnp.linspace(0, 10, 500)
 y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))
@@ -84,7 +84,7 @@ ax[0].imshow(image, cmap='binary_r')
 ax[0].set_title('original')
 
 # Create a noisy version by adding random Gaussian noise
-key = random.PRNGKey(1701)
+key = random.key(1701)
 noisy_image = image + 50 * random.normal(key, image.shape)
 ax[1].imshow(noisy_image, cmap='binary_r')
 ax[1].set_title('noisy')

docs/notebooks/neural_network_with_tfds_data.ipynb (+3 -3)

@@ -97,7 +97,7 @@
 "num_epochs = 10\n",
 "batch_size = 128\n",
 "n_targets = 10\n",
-"params = init_network_params(layer_sizes, random.PRNGKey(0))"
+"params = init_network_params(layer_sizes, random.key(0))"
 ]
 },
 {
@@ -163,7 +163,7 @@
 ],
 "source": [
 "# This works on single examples\n",
-"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
+"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
 "preds = predict(params, random_flattened_image)\n",
 "print(preds.shape)"
 ]
@@ -186,7 +186,7 @@
 ],
 "source": [
 "# Doesn't work with a batch\n",
-"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
+"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
 "try:\n",
 " preds = predict(params, random_flattened_images)\n",
 "except TypeError:\n",

docs/notebooks/neural_network_with_tfds_data.md (+3 -3)

@@ -79,7 +79,7 @@ step_size = 0.01
 num_epochs = 10
 batch_size = 128
 n_targets = 10
-params = init_network_params(layer_sizes, random.PRNGKey(0))
+params = init_network_params(layer_sizes, random.key(0))
 ```
 
 +++ {"id": "BtoNk_yxWtIw"}
@@ -117,7 +117,7 @@ Let's check that our prediction function only works on single images.
 :outputId: ce9d86ed-a830-4832-e04d-10d1abb1fb8a
 
 # This works on single examples
-random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
+random_flattened_image = random.normal(random.key(1), (28 * 28,))
 preds = predict(params, random_flattened_image)
 print(preds.shape)
 ```
@@ -127,7 +127,7 @@ print(preds.shape)
 :outputId: f43bbc9d-bc8f-4168-ee7b-79ee9d33f245
 
 # Doesn't work with a batch
-random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
+random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
 try:
   preds = predict(params, random_flattened_images)
 except TypeError:

docs/notebooks/quickstart.ipynb (+1 -1)

@@ -81,7 +81,7 @@
 }
 ],
 "source": [
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "x = random.normal(key, (10,))\n",
 "print(x)"
 ]

docs/notebooks/quickstart.md (+1 -1)

@@ -59,7 +59,7 @@ We'll be generating random data in the following examples. One big difference be
 :id: u0nseKZNqOoH
 :outputId: 03e20e21-376c-41bb-a6bb-57431823691b
 
-key = random.PRNGKey(0)
+key = random.key(0)
 x = random.normal(key, (10,))
 print(x)
 ```

docs/notebooks/vmapped_log_probs.ipynb (+1 -1)

@@ -483,7 +483,7 @@
 "\n",
 "normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n",
 "\n",
-"key = random.PRNGKey(10003)\n",
+"key = random.key(10003)\n",
 "\n",
 "beta_loc = jnp.zeros(num_features, jnp.float32)\n",
 "beta_log_scale = jnp.zeros(num_features, jnp.float32)\n",

docs/notebooks/vmapped_log_probs.md (+1 -1)

@@ -210,7 +210,7 @@ def normal_sample(key, shape):
 
 normal_sample = jax.jit(normal_sample, static_argnums=(1,))
 
-key = random.PRNGKey(10003)
+key = random.key(10003)
 
 beta_loc = jnp.zeros(num_features, jnp.float32)
 beta_log_scale = jnp.zeros(num_features, jnp.float32)
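A hedged migration note (not part of the commit): downstream code that needs to distinguish new-style typed keys from old-style raw arrays can check the dtype, assuming a JAX version that exposes the abstract `jax.dtypes.prng_key` dtype; `is_typed_key` below is a hypothetical helper for illustration:

```python
import jax
import jax.numpy as jnp

def is_typed_key(x):
    # True for new-style keys made with jax.random.key, False for
    # old-style uint32 arrays made with jax.random.PRNGKey.
    return jnp.issubdtype(x.dtype, jax.dtypes.prng_key)

print(is_typed_key(jax.random.key(0)))      # True
print(is_typed_key(jax.random.PRNGKey(0)))  # False
```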
