Skip to content

Commit

Permalink
update jax-101 to new-style typed PRNG keys
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Mar 7, 2024
1 parent 78fd4f1 commit 0fb4085
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
10 changes: 5 additions & 5 deletions docs/jax-101/05-random-numbers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@
"source": [
"from jax import random\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"\n",
"print(key)"
]
Expand All @@ -293,7 +293,7 @@
"id": "XhFpKnW9F2nF"
},
"source": [
"A key is just an array of shape `(2,)`.\n",
"A single key is an array of scalar shape `()` and key element type.\n",
"\n",
"'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:"
]
Expand Down Expand Up @@ -381,7 +381,7 @@
"source": [
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
"\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"\n",
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
"\n",
Expand Down Expand Up @@ -460,12 +460,12 @@
}
],
"source": [
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"subkeys = random.split(key, 3)\n",
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
"print(\"individually:\", sequence)\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"print(\"all at once: \", random.normal(key, shape=(3,)))"
]
},
Expand Down
10 changes: 5 additions & 5 deletions docs/jax-101/05-random-numbers.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ To avoid this issue, JAX does not use a global state. Instead, random functions
from jax import random
key = random.PRNGKey(42)
key = random.key(42)
print(key)
```

+++ {"id": "XhFpKnW9F2nF"}

A key is just an array of shape `(2,)`.
A single key is an array of scalar shape `()` and key element type.

'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:

Expand Down Expand Up @@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key.

`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.

If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.

It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.

Expand Down Expand Up @@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall
:id: 4nB_TA54D-HT
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56
key = random.PRNGKey(42)
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.PRNGKey(42)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

Expand Down
2 changes: 1 addition & 1 deletion docs/jax-101/06-parallelism.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@
"ys = xs * true_w + true_b + noise\n",
"\n",
"# Initialise parameters and replicate across devices.\n",
"params = init(jax.random.PRNGKey(123))\n",
"params = init(jax.random.key(123))\n",
"n_devices = jax.local_device_count()\n",
"replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/jax-101/06-parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise
# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
params = init(jax.random.key(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
```
Expand Down
4 changes: 2 additions & 2 deletions docs/jax-101/07-state.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
"\n",
"In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n",
"\n",
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey."
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key."
]
},
{
Expand Down Expand Up @@ -351,7 +351,7 @@
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"rng = jax.random.PRNGKey(42)\n",
"rng = jax.random.key(42)\n",
"\n",
"# Generate true data from y = w*x + b + noise\n",
"true_w, true_b = 2, -1\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/jax-101/07-state.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ Notice that the need for a class becomes less clear once we have rewritten it th

In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?

Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.

+++ {"id": "I2SqRx14_z98"}

Expand Down Expand Up @@ -233,7 +233,7 @@ Notice that we manually pipe the params in and out of the update function.
import matplotlib.pyplot as plt
rng = jax.random.PRNGKey(42)
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
Expand Down

0 comments on commit 0fb4085

Please sign in to comment.