
Commit 342887b

Author: jax authors
Merge pull request #20306 from jakevdp:jax-101
PiperOrigin-RevId: 617682492
2 parents d0819ae + d6c07bd commit 342887b

14 files changed (+303, -161 lines)

docs/jep/263-prng.md

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+(prng-design-jep)=
 # JAX PRNG Design
 We want a PRNG design that
 1. is **expressive** in that it is convenient to use and it doesn’t constrain the user’s ability to write numerical programs with exactly the behavior that they want,

docs/tutorials/advanced-debugging.md

Lines changed: 6 additions & 0 deletions
@@ -14,3 +14,9 @@ kernelspec:
 
 (advanced-debugging)=
 # Advanced debugging
+```{note}
+This is a placeholder for a section in the new {ref}`jax-tutorials`.
+
+For the time being, you may find some related content in the old documentation:
+- {doc}`../debugging/index`
+```

docs/tutorials/automatic-differentiation.md

Lines changed: 18 additions & 17 deletions
@@ -63,7 +63,7 @@ dfdx = jax.grad(f)
 The higher-order derivatives of $f$ are:
 
 $$
-\begin{array}{l}s
+\begin{array}{l}
 f'(x) = 3x^2 + 4x -3\\
 f''(x) = 6x + 4\\
 f'''(x) = 6\\
@@ -105,27 +105,27 @@ print(d4fdx(1.))
 The next example shows how to compute gradients with {func}`jax.grad` in a linear logistic regression model. First, the setup:
 
 ```{code-cell}
-key = jax.random.PRNGKey(0)
+key = jax.random.key(0)
 
 def sigmoid(x):
-    return 0.5 * (jnp.tanh(x / 2) + 1)
+  return 0.5 * (jnp.tanh(x / 2) + 1)
 
 # Outputs probability of a label being true.
 def predict(W, b, inputs):
-    return sigmoid(jnp.dot(inputs, W) + b)
+  return sigmoid(jnp.dot(inputs, W) + b)
 
 # Build a toy dataset.
 inputs = jnp.array([[0.52, 1.12, 0.77],
-    [0.88, -1.08, 0.15],
-    [0.52, 0.06, -1.30],
-    [0.74, -2.49, 1.39]])
+                    [0.88, -1.08, 0.15],
+                    [0.52, 0.06, -1.30],
+                    [0.74, -2.49, 1.39]])
 targets = jnp.array([True, True, False, True])
 
 # Training loss is the negative log-likelihood of the training examples.
 def loss(W, b):
-    preds = predict(W, b, inputs)
-    label_probs = preds * targets + (1 - preds) * (1 - targets)
-    return -jnp.sum(jnp.log(label_probs))
+  preds = predict(W, b, inputs)
+  label_probs = preds * targets + (1 - preds) * (1 - targets)
+  return -jnp.sum(jnp.log(label_probs))
 
 # Initialize random model coefficients
 key, W_key, b_key = jax.random.split(key, 3)
@@ -138,20 +138,20 @@ Use the {func}`jax.grad` function with its `argnums` argument to differentiate a
 ```{code-cell}
 # Differentiate `loss` with respect to the first positional argument:
 W_grad = grad(loss, argnums=0)(W, b)
-print('W_grad', W_grad)
+print(f'{W_grad=}')
 
 # Since argnums=0 is the default, this does the same thing:
 W_grad = grad(loss)(W, b)
-print('W_grad', W_grad)
+print(f'{W_grad=}')
 
 # But you can choose different values too, and drop the keyword:
 b_grad = grad(loss, 1)(W, b)
-print('b_grad', b_grad)
+print(f'{b_grad=}')
 
 # Including tuple values
 W_grad, b_grad = grad(loss, (0, 1))(W, b)
-print('W_grad', W_grad)
-print('b_grad', b_grad)
+print(f'{W_grad=}')
+print(f'{b_grad=}')
 ```
 
 The {func}`jax.grad` API has a direct correspondence to the excellent notation in Spivak's classic *Calculus on Manifolds* (1965), also used in Sussman and Wisdom's [*Structure and Interpretation of Classical Mechanics*](https://mitpress.mit.edu/9780262028967/structure-and-interpretation-of-classical-mechanics) (2015) and their [*Functional Differential Geometry*](https://mitpress.mit.edu/9780262019347/functional-differential-geometry) (2013). Both books are open-access. See in particular the "Prologue" section of *Functional Differential Geometry* for a defense of this notation.
@@ -162,7 +162,8 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for
 (automatic-differentiation-nested-lists-tuples-and-dicts)=
 ## 3. Differentiating with respect to nested lists, tuples, and dicts
 
-Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
+Due to JAX's PyTree abstraction (see {ref}`thinking-in-jax-pytrees`), differentiating with
+respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
 
 Continuing the previous example:
 
@@ -181,7 +182,7 @@ You can {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad
 (automatic-differentiation-evaluating-using-jax-value_and_grad)=
 ## 4. Evaluating a function and its gradient using `jax.value_and_grad`
 
-Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value.
+Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value in one pass.
 
 Continuing the previous examples:
 
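Not part of this commit: for readers skimming the diff above, here is a minimal, self-contained sketch of how the updated logistic-regression cells compose with `jax.grad` and `jax.value_and_grad`. The `W`/`b` initialization follows the tutorial's pattern but is assumed here, since those lines fall outside the hunks shown.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # new-style typed PRNG key, as introduced in this change

def sigmoid(x):
  return 0.5 * (jnp.tanh(x / 2) + 1)

def predict(W, b, inputs):
  return sigmoid(jnp.dot(inputs, W) + b)

# Toy dataset, mirroring the tutorial cell above.
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss: negative log-likelihood of the training examples.
def loss(W, b):
  preds = predict(W, b, inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))

# Assumed initialization (outside the hunks above).
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

# Differentiate with respect to both positional arguments at once ...
W_grad, b_grad = jax.grad(loss, argnums=(0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')

# ... or get the loss value and the gradients in one pass.
loss_value, (W_grad, b_grad) = jax.value_and_grad(loss, argnums=(0, 1))(W, b)
print(f'{loss_value=}')
```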

docs/tutorials/debugging.md

Lines changed: 72 additions & 31 deletions
@@ -26,44 +26,58 @@ Let's begin with {func}`jax.debug.print`.
 - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
 - Use Python `print` for static values, such as dtypes and array shapes.
 
-With some JAX transformations, such as {func}`jax.grad` and {func}`jax.vmap`, you can use Python’s built-in `print` function to print out numerical values. However, with {func}`jax.jit` for example, you need to use {func}`jax.debug.print`, because those transformations delay numerical evaluation.
-
-Below is a basic example with {func}`jax.jit`:
+Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
+the Python code is executed with abstract tracers in place of your arrays. Because of this,
+the Python `print` statement will only print this tracer value:
 
 ```{code-cell}
 import jax
 import jax.numpy as jnp
 
 @jax.jit
 def f(x):
-  jax.debug.print("This is `jax.debug.print` of x {x}", x=x)
-  y = jnp.sin(x)
-  jax.debug.print("This is `jax.debug.print` of y {y} 🤯", y=y)
-  return y
+  print("print(x) ->", x)
+  y = jnp.sin(x)
+  print("print(y) ->", y)
+  return y
 
-f(2.)
+result = f(2.)
 ```
 
-{func}`jax.debug.print` can reveal the information about how computations are evaluated.
+Python's `print` executes at trace-time, before the runtime values exist.
+If you want to print the actual runtime values, you can use {func}`jax.debug.print`:
+
+```{code-cell}
+@jax.jit
+def f(x):
+  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
+  y = jnp.sin(x)
+  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
+  return y
+
+result = f(2.)
+```
 
-Here's an example with {func}`jax.vmap`:
+Similarly, within {func}`jax.vmap`, using Python's `print` will only print the tracer;
+to print the values being mapped over, use {func}`jax.debug.print`:
 
 ```{code-cell}
 def f(x):
-  jax.debug.print("This is `jax.debug.print` of x: {}", x)
-  y = jnp.sin(x)
-  jax.debug.print("This is `jax.debug.print` of y: {}", y)
-  return y
+  jax.debug.print("jax.debug.print(x) -> {}", x)
+  y = jnp.sin(x)
+  jax.debug.print("jax.debug.print(y) -> {}", y)
+  return y
 
 xs = jnp.arange(3.)
 
-jax.vmap(f)(xs)
+result = jax.vmap(f)(xs)
 ```
 
-Here's an example with {func}`jax.lax.map`:
+Here's the result with {func}`jax.lax.map`, which is a sequential map rather than a
+vectorization:
 
 ```{code-cell}
-jax.lax.map(f, xs)
+result = jax.lax.map(f, xs)
 ```
 
 Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
@@ -72,10 +86,10 @@ Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only pr
 
 ```{code-cell}
 def f(x):
-  jax.debug.print("This is `jax.debug.print` of x: {}", x)
-  return x ** 2
+  jax.debug.print("jax.debug.print(x) -> {}", x)
+  return x ** 2
 
-jax.grad(f)(1.)
+result = jax.grad(f)(1.)
 ```
 
 Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter.
@@ -85,9 +99,11 @@ For example:
 ```{code-cell}
 @jax.jit
 def f(x, y):
-  jax.debug.print("This is `jax.debug.print of x: {}", x, ordered=True)
-  jax.debug.print("This is `jax.debug.print of y: {}", y, ordered=True)
-  return x + y
+  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
+  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
+  return x + y
+
+f(1, 2)
 ```
 
 To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.
@@ -101,11 +117,24 @@ To pause your compiled JAX program during certain points during debugging, you c
 
 To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in {ref}`advanced-debugging`.)
 
-Example:
+Here is an example of what a debugger session might look like:
 
 ```{code-cell}
-:tags: [raises-exception]
+:tags: [skip-execution]
+
+@jax.jit
+def f(x):
+  y, z = jnp.sin(x), jnp.cos(x)
+  jax.debug.breakpoint()
+  return y * z
+f(2.) # ==> Pauses during execution
+```
 
+![JAX debugger](../_static/debugger.gif)
+
+For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:
+
+```{code-cell}
 def breakpoint_if_nonfinite(x):
   is_finite = jnp.isfinite(x).all()
   def true_fn(x):
@@ -119,20 +148,32 @@ def f(x, y):
   z = x / y
   breakpoint_if_nonfinite(z)
   return z
-f(2., 0.) # ==> Pauses during execution
+
+f(2., 1.) # ==> No breakpoint
 ```
 
-![JAX debugger](../_static/debugger.gif)
+```{code-cell}
+:tags: [skip-execution]
+
+f(2., 0.) # ==> Pauses during execution
+```
 
 ## JAX `debug.callback` for more control during debugging
 
-As mentioned in the beginning, {func}`jax.debug.print` is a small wrapper around {func}`jax.debug.callback`. The {func}`jax.debug.callback` method allows you to have greater control over string formatting and the debugging output, like printing or plotting. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref]`external-callbacks` for more information).
+Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using
+the more flexible {func}`jax.debug.callback`, which gives greater control over the
+host-side logic executed via a Python callback.
+It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other
+transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in
+{ref}`external-callbacks` for more information).
 
 For example:
 
 ```{code-cell}
+import logging
+
 def log_value(x):
-  print("log:", x)
+  logging.warning(f'Logged value: {x}')
 
 @jax.jit
 def f(x):
@@ -142,7 +183,7 @@ def f(x):
 f(1.0);
 ```
 
-This callback is compatible with {func}`jax.vmap` and {func}`jax.grad`:
+This callback is compatible with other transformations, including {func}`jax.vmap` and {func}`jax.grad`:
 
 ```{code-cell}
 x = jnp.arange(5.0)
@@ -155,7 +196,7 @@ jax.grad(f)(1.0);
 
 This can make {func}`jax.debug.callback` useful for general-purpose debugging.
 
-You can learn more about different flavors of JAX callbacks in {ref}`external-callbacks-flavors-of-callback` and {ref}`external-callbacks-exploring-debug-callback`.
+You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`.
 
 ## Next steps
 
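Not part of this commit: the `jax.debug.callback` hunks above show only the changed lines, so the wiring inside `f` is hidden in unchanged context. A sketch of how the pieces plausibly fit together; the `jax.debug.callback(log_value, x)` call is assumed here rather than taken from the diff.

```python
import logging

import jax
import jax.numpy as jnp

def log_value(x):
  # Runs on the host at runtime, outside the compiled computation.
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  # Hand the runtime value of x to the host-side callback.
  jax.debug.callback(log_value, x)
  return x ** 2

f(1.0)

# The same callback composes with other transformations:
jax.vmap(f)(jnp.arange(5.0))
jax.grad(f)(1.0)
```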

docs/tutorials/external-callbacks.md

Lines changed: 7 additions & 4 deletions
@@ -12,6 +12,13 @@ kernelspec:
   name: python3
 ---
 
+```{code-cell}
+:tags: [remove-cell]
+
+# This ensures that code cell tracebacks appearing below will be concise.
+%xmode minimal
+```
+
 (external-callbacks)=
 # External callbacks
 
@@ -117,10 +124,6 @@ jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
 
 However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:
 
-```{code-cell}
-%xmode minimal
-```
-
 ```{code-cell}
 :tags: [raises-exception]
 
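Not part of this commit: the surrounding text in this file concerns `jax.pure_callback`, which the hunks only touch indirectly. As background, a small sketch of the pattern the relocated `%xmode minimal` cell keeps readable when a later cell intentionally raises; the `host_sin` helper is illustrative, not taken from the file.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # Runs on the host with NumPy; JAX cannot trace or differentiate through it.
  return np.sin(x)

def f(x):
  # Because JAX cannot introspect the callback, the output shape/dtype
  # must be declared up front.
  out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(host_sin, out_spec, x)

print(jax.jit(f)(jnp.arange(3.0)))  # fine under jit
# jax.grad(f)(1.0)  # raises: pure_callback has undefined autodiff semantics
```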

docs/tutorials/index.rst

Lines changed: 9 additions & 2 deletions
@@ -17,9 +17,7 @@ JAX 101
 .. toctree::
    :maxdepth: 1
 
-   installation
    quickstart
-   jax-as-accelerated-numpy
    thinking-in-jax
    jit-compilation
    automatic-vectorization
@@ -55,3 +53,12 @@ JAX 301
    jax-primitives
    jaxpr
    advanced-compilation
+
+
+Reference
+---------
+
+.. toctree::
+   :maxdepth: 1
+
+   installation

docs/tutorials/installation.md

Lines changed: 9 additions & 6 deletions
@@ -1,5 +1,5 @@
 (installation)=
-# How to install JAX
+# Installing JAX
 
 This guide provides instructions for:
 
@@ -9,11 +9,14 @@ This guide provides instructions for:
 
 **TL;DR** For most users, a typical JAX installation may look something like this:
 
-| Hardware | Installation |
-|------------------------------------|--------------------------------------------|
-| CPU-only, Linux/macOS/Windows | `pip install -U "jax[cpu]"` |
-| NVIDIA, CUDA 12, x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`|
-
+* **CPU-only (Linux/macOS/Windows)**
+  ```
+  pip install -U "jax[cpu]"
+  ```
+* **GPU (NVIDIA, CUDA 12, x86_64)**
+  ```
+  pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+  ```
 
 (install-supported-platforms)=
 ## Supported platforms
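Not part of the diff: after either install command, a quick sanity check along these lines confirms the installed version and which backend JAX picked up.

```python
import jax

print(jax.__version__)        # installed jax version
print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'
print(jax.devices())          # the devices visible to JAX
```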

docs/tutorials/jax-as-accelerated-numpy.md

Lines changed: 0 additions & 8 deletions
This file was deleted.

docs/tutorials/jit-compilation.md

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,13 @@ kernelspec:
   name: python3
 ---
 
+```{code-cell}
+:tags: [remove-cell]
+
+# This ensures that code cell tracebacks appearing below will be concise.
+%xmode minimal
+```
+
 (jit-compilation)=
 # Just-in-time compilation
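Not part of this commit: `%xmode minimal` is an IPython magic that shortens the tracebacks shown in later rendered cells. A sketch of the kind of cell it is meant to keep readable, namely a jit-compiled function whose Python-level branch fails at trace time:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # Python control flow on a traced value fails at trace time; with
  # `%xmode minimal`, only the short final error message appears in the docs.
  if x > 0:
    return x
  return -x

f(jnp.array(1.0))  # raises an error about using a tracer in a Python `if`
```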
