You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/jep/263-prng.md
+1Lines changed: 1 addition & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -1,3 +1,4 @@
1
+
(prng-design-jep)=
1
2
# JAX PRNG Design
2
3
We want a PRNG design that
3
4
1. is **expressive** in that it is convenient to use and it doesn’t constrain the user’s ability to write numerical programs with exactly the behavior that they want,
@@ -138,20 +138,20 @@ Use the {func}`jax.grad` function with its `argnums` argument to differentiate a
138
138
```{code-cell}
139
139
# Differentiate `loss` with respect to the first positional argument:
140
140
W_grad = grad(loss, argnums=0)(W, b)
141
-
print('W_grad', W_grad)
141
+
print(f'{W_grad=}')
142
142
143
143
# Since argnums=0 is the default, this does the same thing:
144
144
W_grad = grad(loss)(W, b)
145
-
print('W_grad', W_grad)
145
+
print(f'{W_grad=}')
146
146
147
147
# But you can choose different values too, and drop the keyword:
148
148
b_grad = grad(loss, 1)(W, b)
149
-
print('b_grad', b_grad)
149
+
print(f'{b_grad=}')
150
150
151
151
# Including tuple values
152
152
W_grad, b_grad = grad(loss, (0, 1))(W, b)
153
-
print('W_grad', W_grad)
154
-
print('b_grad', b_grad)
153
+
print(f'{W_grad=}')
154
+
print(f'{b_grad=}')
155
155
```
156
156
157
157
The {func}`jax.grad` API has a direct correspondence to the excellent notation in Spivak's classic *Calculus on Manifolds* (1965), also used in Sussman and Wisdom's [*Structure and Interpretation of Classical Mechanics*](https://mitpress.mit.edu/9780262028967/structure-and-interpretation-of-classical-mechanics) (2015) and their [*Functional Differential Geometry*](https://mitpress.mit.edu/9780262019347/functional-differential-geometry) (2013). Both books are open-access. See in particular the "Prologue" section of *Functional Differential Geometry* for a defense of this notation.
@@ -162,7 +162,8 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for
## 4. Evaluating a function and its gradient using `jax.value_and_grad`
183
184
184
-
Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value.
185
+
Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value in one pass.
Copy file name to clipboardExpand all lines: docs/tutorials/debugging.md
+72-31Lines changed: 72 additions & 31 deletions
Original file line number
Diff line number
Diff line change
@@ -26,44 +26,58 @@ Let's begin with {func}`jax.debug.print`.
26
26
- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
27
27
- Use Python `print` for static values, such as dtypes and array shapes.
28
28
29
-
With some JAX transformations, such as {func}`jax.grad` and {func}`jax.vmap`, you can use Python’s built-in `print`function to print out numerical values. However, with {func}`jax.jit` for example, you need to use {func}`jax.debug.print`, because those transformations delay numerical evaluation.
30
-
31
-
Below is a basic example with {func}`jax.jit`:
29
+
Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
30
+
the Python code is executed with abstract tracers in place of your arrays. Because of this,
31
+
the Python `print` statement will only print this tracer value:
32
32
33
33
```{code-cell}
34
34
import jax
35
35
import jax.numpy as jnp
36
36
37
37
@jax.jit
38
38
def f(x):
39
-
jax.debug.print("This is `jax.debug.print` of x {x}", x=x)
40
-
y = jnp.sin(x)
41
-
jax.debug.print("This is `jax.debug.print` of y {y} 🤯", y=y)
42
-
return y
39
+
print("print(x) ->", x)
40
+
y = jnp.sin(x)
41
+
print("print(y) ->", y)
42
+
return y
43
43
44
-
f(2.)
44
+
result = f(2.)
45
45
```
46
46
47
-
{func}`jax.debug.print` can reveal the information about how computations are evaluated.
47
+
Python's `print` executes at trace-time, before the runtime values exist.
48
+
If you want to print the actual runtime values, you can use {func}`jax.debug.print`:
49
+
50
+
```{code-cell}
51
+
@jax.jit
52
+
def f(x):
53
+
jax.debug.print("jax.debug.print(x) -> {x}", x=x)
54
+
y = jnp.sin(x)
55
+
jax.debug.print("jax.debug.print(y) -> {y}", y=y)
56
+
return y
57
+
58
+
result = f(2.)
59
+
```
48
60
49
-
Here's an example with {func}`jax.vmap`:
61
+
Similarly, within {func}`jax.vmap`, using Python's `print` will only print the tracer;
62
+
to print the values being mapped over, use {func}`jax.debug.print`:
50
63
51
64
```{code-cell}
52
65
def f(x):
53
-
jax.debug.print("This is `jax.debug.print` of x: {}", x)
54
-
y = jnp.sin(x)
55
-
jax.debug.print("This is `jax.debug.print` of y: {}", y)
56
-
return y
66
+
jax.debug.print("jax.debug.print(x) -> {}", x)
67
+
y = jnp.sin(x)
68
+
jax.debug.print("jax.debug.print(y) -> {}", y)
69
+
return y
57
70
58
71
xs = jnp.arange(3.)
59
72
60
-
jax.vmap(f)(xs)
73
+
result = jax.vmap(f)(xs)
61
74
```
62
75
63
-
Here's an example with {func}`jax.lax.map`:
76
+
Here's the result with {func}`jax.lax.map`, which is a sequential map rather than a
77
+
vectorization:
64
78
65
79
```{code-cell}
66
-
jax.lax.map(f, xs)
80
+
result = jax.lax.map(f, xs)
67
81
```
68
82
69
83
Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
@@ -72,10 +86,10 @@ Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only pr
72
86
73
87
```{code-cell}
74
88
def f(x):
75
-
jax.debug.print("This is `jax.debug.print` of x: {}", x)
76
-
return x ** 2
89
+
jax.debug.print("jax.debug.print(x) -> {}", x)
90
+
return x ** 2
77
91
78
-
jax.grad(f)(1.)
92
+
result = jax.grad(f)(1.)
79
93
```
80
94
81
95
Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter.
@@ -85,9 +99,11 @@ For example:
85
99
```{code-cell}
86
100
@jax.jit
87
101
def f(x, y):
88
-
jax.debug.print("This is `jax.debug.print of x: {}", x, ordered=True)
89
-
jax.debug.print("This is `jax.debug.print of y: {}", y, ordered=True)
jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
104
+
return x + y
105
+
106
+
f(1, 2)
91
107
```
92
108
93
109
To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.
@@ -101,11 +117,24 @@ To pause your compiled JAX program during certain points during debugging, you c
101
117
102
118
To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in {ref}`advanced-debugging`.)
103
119
104
-
Example:
120
+
Here is an example of what a debugger session might look like:
105
121
106
122
```{code-cell}
107
-
:tags: [raises-exception]
123
+
:tags: [skip-execution]
124
+
125
+
@jax.jit
126
+
def f(x):
127
+
y, z = jnp.sin(x, jnp.cos(x))
128
+
jax.debug.breakpoint()
129
+
return y * z
130
+
f(2.) # ==> Pauses during execution
131
+
```
108
132
133
+

134
+
135
+
For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:
136
+
137
+
```{code-cell}
109
138
def breakpoint_if_nonfinite(x):
110
139
is_finite = jnp.isfinite(x).all()
111
140
def true_fn(x):
@@ -119,20 +148,32 @@ def f(x, y):
119
148
z = x / y
120
149
breakpoint_if_nonfinite(z)
121
150
return z
122
-
f(2., 0.) # ==> Pauses during execution
151
+
152
+
f(2., 1.) # ==> No breakpoint
123
153
```
124
154
125
-

155
+
```{code-cell}
156
+
:tags: [skip-execution]
157
+
158
+
f(2., 0.) # ==> Pauses during execution
159
+
```
126
160
127
161
## JAX `debug.callback` for more control during debugging
128
162
129
-
As mentioned in the beginning, {func}`jax.debug.print` is a small wrapper around {func}`jax.debug.callback`. The {func}`jax.debug.callback` method allows you to have greater control over string formatting and the debugging output, like printing or plotting. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref]`external-callbacks` for more information).
163
+
Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using
164
+
the more flexible {func}`jax.debug.callback`, which gives greater control over the
165
+
host-side logic executed via a Python callback.
166
+
It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other
167
+
transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in
168
+
{ref}`external-callbacks` for more information).
130
169
131
170
For example:
132
171
133
172
```{code-cell}
173
+
import logging
174
+
134
175
def log_value(x):
135
-
print("log:", x)
176
+
logging.warning(f'Logged value: {x}')
136
177
137
178
@jax.jit
138
179
def f(x):
@@ -142,7 +183,7 @@ def f(x):
142
183
f(1.0);
143
184
```
144
185
145
-
This callback is compatible with {func}`jax.vmap` and {func}`jax.grad`:
186
+
This callback is compatible with other transformations, including {func}`jax.vmap` and {func}`jax.grad`:
146
187
147
188
```{code-cell}
148
189
x = jnp.arange(5.0)
@@ -155,7 +196,7 @@ jax.grad(f)(1.0);
155
196
156
197
This can make {func}`jax.debug.callback` useful for general-purpose debugging.
157
198
158
-
You can learn more about different flavors of JAX callbacks in {ref}`external-callbacks-flavors-of-callback` and {ref}`external-callbacks-exploring-debug-callback`.
199
+
You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`.
0 commit comments