Skip to content

Commit 0302e4c

Browse files
author
jax authors
committed
Merge pull request #17741 from froystig:new-style-key-docs
PiperOrigin-RevId: 614080080
2 parents c7bc95d + 721ca3f commit 0302e4c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+168
-128
lines changed

docs/Custom_Operation_for_GPUs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ per_core_batch_size=4
304304
seq_len=512
305305
emb_dim=512
306306
x = jax.random.normal(
307-
jax.random.PRNGKey(0),
307+
jax.random.key(0),
308308
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
309309
dtype=jnp.bfloat16,
310310
)
@@ -1049,7 +1049,7 @@ per_core_batch_size=4
10491049
seq_len=512
10501050
emb_dim=512
10511051
x = jax.random.normal(
1052-
jax.random.PRNGKey(0),
1052+
jax.random.key(0),
10531053
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
10541054
dtype=jnp.bfloat16,
10551055
)

docs/async_dispatch.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ program:
99
>>> import numpy as np
1010
>>> import jax.numpy as jnp
1111
>>> from jax import random
12-
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
12+
>>> x = random.uniform(random.key(0), (1000, 1000))
1313
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
1414
>>> # will block until the value is ready.
1515
>>> jnp.dot(x, x) + 3. # doctest: +SKIP

docs/device_memory_profiling.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def func2(x):
5959
y = func1(x)
6060
return y, jnp.tile(x, 10) + 1
6161

62-
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
62+
x = jax.random.normal(jax.random.key(42), (1000, 1000))
6363
y, z = func2(x)
6464

6565
z.block_until_ready()
@@ -107,14 +107,14 @@ import jax.numpy as jnp
107107
import jax.profiler
108108

109109
def afunction():
110-
return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
110+
return jax.random.normal(jax.random.key(77), (1000000,))
111111

112112
z = afunction()
113113

114114
def anotherfunc():
115115
arrays = []
116116
for i in range(1, 10):
117-
x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
117+
x = jax.random.normal(jax.random.key(42), (i, 10000))
118118
arrays.append(x)
119119
x.block_until_ready()
120120
jax.profiler.save_device_memory_profile(f"memory{i}.prof")

docs/jax-101/05-random-numbers.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@
282282
"source": [
283283
"from jax import random\n",
284284
"\n",
285-
"key = random.PRNGKey(42)\n",
285+
"key = random.key(42)\n",
286286
"\n",
287287
"print(key)"
288288
]
@@ -293,7 +293,7 @@
293293
"id": "XhFpKnW9F2nF"
294294
},
295295
"source": [
296-
"A key is just an array of shape `(2,)`.\n",
296+
"A single key is an array of scalar shape `()` and key element type.\n",
297297
"\n",
298298
"'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:"
299299
]
@@ -381,7 +381,7 @@
381381
"source": [
382382
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
383383
"\n",
384-
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
384+
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
385385
"\n",
386386
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
387387
"\n",
@@ -460,12 +460,12 @@
460460
}
461461
],
462462
"source": [
463-
"key = random.PRNGKey(42)\n",
463+
"key = random.key(42)\n",
464464
"subkeys = random.split(key, 3)\n",
465465
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
466466
"print(\"individually:\", sequence)\n",
467467
"\n",
468-
"key = random.PRNGKey(42)\n",
468+
"key = random.key(42)\n",
469469
"print(\"all at once: \", random.normal(key, shape=(3,)))"
470470
]
471471
},

docs/jax-101/05-random-numbers.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,14 @@ To avoid this issue, JAX does not use a global state. Instead, random functions
150150
151151
from jax import random
152152
153-
key = random.PRNGKey(42)
153+
key = random.key(42)
154154
155155
print(key)
156156
```
157157

158158
+++ {"id": "XhFpKnW9F2nF"}
159159

160-
A key is just an array of shape `(2,)`.
160+
A single key is an array of scalar shape `()` and key element type.
161161

162162
'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:
163163

@@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key.
201201

202202
`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.
203203

204-
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
204+
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
205205

206206
It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
207207

@@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall
240240
:id: 4nB_TA54D-HT
241241
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56
242242
243-
key = random.PRNGKey(42)
243+
key = random.key(42)
244244
subkeys = random.split(key, 3)
245245
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
246246
print("individually:", sequence)
247247
248-
key = random.PRNGKey(42)
248+
key = random.key(42)
249249
print("all at once: ", random.normal(key, shape=(3,)))
250250
```
251251

docs/jax-101/06-parallelism.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@
623623
"ys = xs * true_w + true_b + noise\n",
624624
"\n",
625625
"# Initialise parameters and replicate across devices.\n",
626-
"params = init(jax.random.PRNGKey(123))\n",
626+
"params = init(jax.random.key(123))\n",
627627
"n_devices = jax.local_device_count()\n",
628628
"replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)"
629629
]

docs/jax-101/06-parallelism.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ noise = 0.5 * np.random.normal(size=(128, 1))
291291
ys = xs * true_w + true_b + noise
292292
293293
# Initialise parameters and replicate across devices.
294-
params = init(jax.random.PRNGKey(123))
294+
params = init(jax.random.key(123))
295295
n_devices = jax.local_device_count()
296296
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
297297
```

docs/jax-101/07-state.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@
249249
"\n",
250250
"In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n",
251251
"\n",
252-
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey."
252+
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key."
253253
]
254254
},
255255
{
@@ -351,7 +351,7 @@
351351
"source": [
352352
"import matplotlib.pyplot as plt\n",
353353
"\n",
354-
"rng = jax.random.PRNGKey(42)\n",
354+
"rng = jax.random.key(42)\n",
355355
"\n",
356356
"# Generate true data from y = w*x + b + noise\n",
357357
"true_w, true_b = 2, -1\n",

docs/jax-101/07-state.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ Notice that the need for a class becomes less clear once we have rewritten it th
166166

167167
In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
168168

169-
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.
169+
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.
170170

171171
+++ {"id": "I2SqRx14_z98"}
172172

@@ -233,7 +233,7 @@ Notice that we manually pipe the params in and out of the update function.
233233
234234
import matplotlib.pyplot as plt
235235
236-
rng = jax.random.PRNGKey(42)
236+
rng = jax.random.key(42)
237237
238238
# Generate true data from y = w*x + b + noise
239239
true_w, true_b = 2, -1

docs/jax.nn.initializers.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
1414

1515
An initializer is a function that takes three arguments:
1616
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
17-
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
18-
key used when generating random numbers to initialize the array.
17+
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
18+
:func:`jax.random.key`), used to generate random numbers to initialize the array.
1919

2020
.. autosummary::
2121
:toctree: _autosummary

0 commit comments

Comments
 (0)