update notebooks to new-style typed PRNG keys
froystig committed Mar 7, 2024
1 parent 0bdbe76 commit 75a53f4
Showing 18 changed files with 48 additions and 48 deletions.
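The change is mechanical throughout: every call to `random.PRNGKey(seed)` becomes `random.key(seed)`. As a minimal sketch of what the new-style typed keys look like (assuming a JAX release that ships `jax.random.key`; the printed key dtype below reflects the default threefry implementation and may differ in other configurations):

```python
import jax

# Old-style key: a raw uint32 array of shape (2,).
old_key = jax.random.PRNGKey(0)
print(old_key.shape, old_key.dtype)   # (2,) uint32

# New-style typed key: a scalar array with a special key dtype.
new_key = jax.random.key(0)
print(new_key.shape, new_key.dtype)   # () key<fry>

# Both kinds of key are consumed the same way by jax.random functions,
# so the downstream notebook code is unchanged.
key, subkey = jax.random.split(new_key)
x = jax.random.normal(subkey, (3,))
```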
8 changes: 4 additions & 4 deletions docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -1006,7 +1006,7 @@
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by two unsigned-int32s that we call a __key__:"
"The random state is described by a special array element that we call a __key__:"
]
},
{
@@ -1030,7 +1030,7 @@
],
"source": [
"from jax import random\n",
"key = random.PRNGKey(0)\n",
"key = random.key(0)\n",
"key"
]
},
@@ -2121,7 +2121,7 @@
}
],
"source": [
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype"
]
},
@@ -2188,7 +2188,7 @@
"source": [
"import jax.numpy as jnp\n",
"from jax import random\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype # --> dtype('float64')"
]
},
8 changes: 4 additions & 4 deletions docs/notebooks/Common_Gotchas_in_JAX.md
@@ -463,14 +463,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha

JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.

-The random state is described by two unsigned-int32s that we call a __key__:
+The random state is described by a special array element that we call a __key__:

```{code-cell} ipython3
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
from jax import random
-key = random.PRNGKey(0)
+key = random.key(0)
key
```

@@ -1071,7 +1071,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the
:id: CNNGtzM3NDkO
:outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8
-x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
+x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
```

@@ -1117,7 +1117,7 @@ We can then confirm that `x64` mode is enabled:
import jax.numpy as jnp
from jax import random
-x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
+x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
```

@@ -131,7 +131,7 @@
],
"source": [
"# Create an array of random values:\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"# and use jax.device_put to distribute it across devices:\n",
"y = jax.device_put(x, sharding.reshape(4, 2))\n",
"jax.debug.visualize_array_sharding(y)"
@@ -272,7 +272,7 @@
"outputs": [],
"source": [
"import jax\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))"
"x = jax.random.normal(jax.random.key(0), (8192, 8192))"
]
},
{
@@ -1513,7 +1513,7 @@
},
"outputs": [],
"source": [
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"x = jax.device_put(x, sharding.reshape(4, 2))"
]
},
@@ -1738,7 +1738,7 @@
"layer_sizes = [784, 8192, 8192, 8192, 10]\n",
"batch_size = 8192\n",
"\n",
"params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)"
"params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)"
]
},
{
@@ -2184,7 +2184,7 @@
" numbers = jax.random.uniform(key, x.shape)\n",
" return x + numbers\n",
"\n",
"key = jax.random.PRNGKey(42)\n",
"key = jax.random.key(42)\n",
"x_sharding = jax.sharding.PositionalSharding(jax.devices())\n",
"x = jax.device_put(jnp.arange(24), x_sharding)"
]
@@ -81,7 +81,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b
# Create an array of random values:
-x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
+x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)
@@ -144,7 +144,7 @@ For example, here's a value with a single-device `Sharding`:
:id: VmoX4SUp3vGJ
import jax
-x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
+x = jax.random.normal(jax.random.key(0), (8192, 8192))
```

```{code-cell}
@@ -609,7 +609,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```{code-cell}
:id: Q1wuDp-L3vGT
-x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
+x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
```

@@ -720,7 +720,7 @@ def init_model(key, layer_sizes, batch_size):
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
-params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)
+params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
```

+++ {"id": "sJv_h0AS2drh"}
@@ -902,7 +902,7 @@ def f(key, x):
numbers = jax.random.uniform(key, x.shape)
return x + numbers
-key = jax.random.PRNGKey(42)
+key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding)
```
6 changes: 3 additions & 3 deletions docs/notebooks/Neural_Network_and_Data_Loading.ipynb
@@ -84,7 +84,7 @@
"num_epochs = 8\n",
"batch_size = 128\n",
"n_targets = 10\n",
"params = init_network_params(layer_sizes, random.PRNGKey(0))"
"params = init_network_params(layer_sizes, random.key(0))"
]
},
{
@@ -150,7 +150,7 @@
],
"source": [
"# This works on single examples\n",
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
"preds = predict(params, random_flattened_image)\n",
"print(preds.shape)"
]
@@ -173,7 +173,7 @@
],
"source": [
"# Doesn't work with a batch\n",
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
"try:\n",
" preds = predict(params, random_flattened_images)\n",
"except TypeError:\n",
6 changes: 3 additions & 3 deletions docs/notebooks/Neural_Network_and_Data_Loading.md
@@ -71,7 +71,7 @@ step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
-params = init_network_params(layer_sizes, random.PRNGKey(0))
+params = init_network_params(layer_sizes, random.key(0))
```

+++ {"id": "BtoNk_yxWtIw"}
@@ -109,7 +109,7 @@ Let's check that our prediction function only works on single images.
:outputId: 9d3b29e8-fab3-4ecb-9f63-bc8c092f9006
# This works on single examples
-random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
+random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
@@ -119,7 +119,7 @@ print(preds.shape)
:outputId: d5d20211-b6da-44e9-f71e-946f2a9d0fc4
# Doesn't work with a batch
-random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
+random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
2 changes: 1 addition & 1 deletion docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
@@ -66,7 +66,7 @@
},
"outputs": [],
"source": [
"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
"x = random.normal(random.key(0), (5000, 5000))\n",
"def f(w, b, x):\n",
" return jnp.tanh(jnp.dot(x, w) + b)\n",
"fast_f = jit(f)"
2 changes: 1 addition & 1 deletion docs/notebooks/Writing_custom_interpreters_in_Jax.md
@@ -48,7 +48,7 @@ JAX provides a NumPy-like API for numerical computing which can be used as is, b
```{code-cell} ipython3
:id: HmlMcICOcSXR
-x = random.normal(random.PRNGKey(0), (5000, 5000))
+x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
8 changes: 4 additions & 4 deletions docs/notebooks/autodiff_cookbook.ipynb
@@ -27,7 +27,7 @@
"from jax import grad, jit, vmap\n",
"from jax import random\n",
"\n",
"key = random.PRNGKey(0)"
"key = random.key(0)"
]
},
{
@@ -1055,7 +1055,7 @@
" outs, = vmap(vjp_fun)(M)\n",
" return outs\n",
"\n",
"key = random.PRNGKey(0)\n",
"key = random.key(0)\n",
"num_covecs = 128\n",
"U = random.normal(key, (num_covecs,) + y.shape)\n",
"\n",
@@ -1306,7 +1306,7 @@
"outputs": [],
"source": [
"def check(seed):\n",
" key = random.PRNGKey(seed)\n",
" key = random.key(seed)\n",
"\n",
" # random coeffs for u and v\n",
" key, subkey = random.split(key)\n",
@@ -1399,7 +1399,7 @@
"outputs": [],
"source": [
"def check(seed):\n",
" key = random.PRNGKey(seed)\n",
" key = random.key(seed)\n",
"\n",
" # random coeffs for u and v\n",
" key, subkey = random.split(key)\n",
8 changes: 4 additions & 4 deletions docs/notebooks/autodiff_cookbook.md
@@ -29,7 +29,7 @@ import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
-key = random.PRNGKey(0)
+key = random.key(0)
```

+++ {"id": "YxnjtAGN6vu2"}
@@ -614,7 +614,7 @@ def vmap_mjp(f, x, M):
outs, = vmap(vjp_fun)(M)
return outs
-key = random.PRNGKey(0)
+key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
@@ -770,7 +770,7 @@ Here's a check:
:id: BGZV__zupIMS
def check(seed):
-key = random.PRNGKey(seed)
+key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
@@ -833,7 +833,7 @@ Here's a check of the VJP rules:
:id: 4J7edvIBttcU
def check(seed):
-key = random.PRNGKey(seed)
+key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
4 changes: 2 additions & 2 deletions docs/notebooks/convolutions.ipynb
@@ -60,7 +60,7 @@
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"key = random.PRNGKey(1701)\n",
"key = random.key(1701)\n",
"\n",
"x = jnp.linspace(0, 10, 500)\n",
"y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))\n",
@@ -130,7 +130,7 @@
"ax[0].set_title('original')\n",
"\n",
"# Create a noisy version by adding random Gaussian noise\n",
"key = random.PRNGKey(1701)\n",
"key = random.key(1701)\n",
"noisy_image = image + 50 * random.normal(key, image.shape)\n",
"ax[1].imshow(noisy_image, cmap='binary_r')\n",
"ax[1].set_title('noisy')\n",
4 changes: 2 additions & 2 deletions docs/notebooks/convolutions.md
@@ -43,7 +43,7 @@ from jax import random
import jax.numpy as jnp
import numpy as np
-key = random.PRNGKey(1701)
+key = random.key(1701)
x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))
@@ -84,7 +84,7 @@ ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')
# Create a noisy version by adding random Gaussian noise
-key = random.PRNGKey(1701)
+key = random.key(1701)
noisy_image = image + 50 * random.normal(key, image.shape)
ax[1].imshow(noisy_image, cmap='binary_r')
ax[1].set_title('noisy')
6 changes: 3 additions & 3 deletions docs/notebooks/neural_network_with_tfds_data.ipynb
@@ -97,7 +97,7 @@
"num_epochs = 10\n",
"batch_size = 128\n",
"n_targets = 10\n",
"params = init_network_params(layer_sizes, random.PRNGKey(0))"
"params = init_network_params(layer_sizes, random.key(0))"
]
},
{
@@ -163,7 +163,7 @@
],
"source": [
"# This works on single examples\n",
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
"preds = predict(params, random_flattened_image)\n",
"print(preds.shape)"
]
@@ -186,7 +186,7 @@
],
"source": [
"# Doesn't work with a batch\n",
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
"try:\n",
" preds = predict(params, random_flattened_images)\n",
"except TypeError:\n",
6 changes: 3 additions & 3 deletions docs/notebooks/neural_network_with_tfds_data.md
@@ -79,7 +79,7 @@ step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
-params = init_network_params(layer_sizes, random.PRNGKey(0))
+params = init_network_params(layer_sizes, random.key(0))
```

+++ {"id": "BtoNk_yxWtIw"}
@@ -117,7 +117,7 @@ Let's check that our prediction function only works on single images.
:outputId: ce9d86ed-a830-4832-e04d-10d1abb1fb8a
# This works on single examples
-random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
+random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
@@ -127,7 +127,7 @@ print(preds.shape)
:outputId: f43bbc9d-bc8f-4168-ee7b-79ee9d33f245
# Doesn't work with a batch
-random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
+random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:

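For code that mixes the two key styles during a migration like this one, here is a minimal sketch of detecting and converting typed keys, assuming `jax.dtypes.prng_key`, `jax.random.key_data`, and `jax.random.wrap_key_data` are available in the installed JAX version:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# New-style keys are recognizable by their dtype.
assert jnp.issubdtype(key.dtype, jax.dtypes.prng_key)

# Recover raw uint32 key data for code that still expects old-style keys,
# then wrap the raw data back into a typed key.
raw = jax.random.key_data(key)          # uint32 array of shape (2,)
typed_again = jax.random.wrap_key_data(raw)
```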