`docs/tutorials/advanced-autodiff.md` (1 addition, 1 deletion)

````diff
@@ -13,7 +13,7 @@ kernelspec:
 ---
 
 (advanced-autodiff)=
-# Advanced automatic differentiation 201
+# Advanced automatic differentiation
 
 In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.
````
`docs/tutorials/automatic-differentiation.md` (6 additions, 6 deletions)

````diff
@@ -13,9 +13,9 @@ kernelspec:
 ---
 
 (automatic-differentiation)=
-# Automatic differentiation 101
+# Automatic differentiation
 
-In this tutorial, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
+In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
@@ -28,9 +28,9 @@ Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advan
 While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.
 
 (automatic-differentiation-taking-gradients)=
-## 1. Taking gradients with `jax.grad`
+## 1. Taking gradients with `jax.grad`
 
-In JAX, you can differentiate a function with the {func}`jax.grad` transformation:
+In JAX, you can differentiate a scalar-valued function with the {func}`jax.grad` transformation:
 
 ```{code-cell}
 import jax
@@ -162,7 +162,7 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for
 ## 3. Differentiating with respect to nested lists, tuples, and dicts
 
-Due to JAX's PyTree abstraction (see {ref}`pytrees-what-is-a-pytree`), differentiating with
+Due to JAX's PyTree abstraction (see {ref}`working-with-pytrees`), differentiating with
 respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
 
 Continuing the previous example:
@@ -176,7 +176,7 @@ def loss2(params_dict):
 print(grad(loss2)({'W': W, 'b': b}))
 ```
 
-You can {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on).
+You can create {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on).
````
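To make the behaviour touched by this diff concrete, here is a self-contained sketch of `jax.grad` on a scalar-valued function and on a dict of parameters. The `loss`/`loss2` bodies and the values of `W` and `b` are illustrative assumptions, not the tutorial's actual definitions.

```python
import jax
import jax.numpy as jnp

# jax.grad differentiates a scalar-valued function with respect to its
# first argument by default; argnums selects other positional arguments.
def loss(w, b):  # illustrative loss, not the tutorial's definition
    return jnp.sum((w * 2.0 + b) ** 2)

grad_w = jax.grad(loss, argnums=0)(1.0, 0.5)  # d(loss)/dw at (1.0, 0.5)

# Thanks to JAX's pytree abstraction, differentiating with respect to a
# dict of parameters "just works": the gradient has the same structure.
def loss2(params_dict):
    return jnp.sum((params_dict['W'] * 2.0 + params_dict['b']) ** 2)

grads = jax.grad(loss2)({'W': 1.0, 'b': 0.5})
print(grads)  # a dict with the same keys: {'W': ..., 'b': ...}
```

The returned `grads` dict mirrors the input pytree, which is what makes this pattern composable with `jax.jit` and `jax.vmap`.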
`docs/tutorials/automatic-vectorization.md` (6 additions, 5 deletions)

````diff
@@ -13,11 +13,12 @@ kernelspec:
 ---
 
 (automatic-vectorization)=
-# Automatic Vectorization in JAX
+# Automatic vectorization
 
-In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`.
+In the previous section we discussed JIT compilation via the {func}`jax.jit` function.
+This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`.
 
-## Manual Vectorization
+## Manual vectorization
 
 Consider the following simple code that computes the convolution of two one-dimensional vectors:
````
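As a sketch of the manual-versus-automatic contrast this tutorial builds toward: the `convolve` definition below is an assumption modelled on the one-dimensional convolution the text describes, and `jax.vmap` lifts it over a leading batch dimension without any loop-rewriting.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # "Valid" 1D convolution of a single vector with a length-3 window
    # (illustrative definition, assumed from the tutorial's description).
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5.0)
w = jnp.array([2.0, 3.0, 4.0])

# Stack inputs along a new leading batch axis...
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

# ...and jax.vmap maps convolve over that axis automatically.
batched = jax.vmap(convolve)(xs, ws)
print(batched.shape)  # one result row per batch element
```

The vmapped version produces the same values as calling `convolve` once per row, but as a single vectorized computation.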
`docs/tutorials/debugging.md` (6 additions, 6 deletions)

````diff
@@ -13,22 +13,22 @@ kernelspec:
 ---
 
 (debugging)=
-# Debugging 101
+# Introduction to debugging
 
-This tutorial introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.
+This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.
 
 Let's begin with {func}`jax.debug.print`.
 
-## JAX `debug.print` for high-level debugging
+## JAX `debug.print` for high-level
 
 **TL;DR** Here is a rule of thumb:
 
 - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
-- Use Python `print` for static values, such as dtypes and array shapes.
+- Use Python {func}`print` for static values, such as dtypes and array shapes.
 
 Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
 the Python code is executed with abstract tracers in place of your arrays. Because of this,
-the Python `print` statement will only print this tracer value:
+the Python {func}`print` function will only print this tracer value:
 
 ```{code-cell}
 import jax
@@ -82,7 +82,7 @@ result = jax.lax.map(f, xs)
 Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
 
-Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's `print`, but it's consistent if you apply {func}`jax.jit` during the call.
+Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's {func}`print`, but it's consistent if you apply {func}`jax.jit` during the call.
````
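The rule of thumb above can be demonstrated in a few lines; this is a minimal sketch, with illustrative message text:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Static: under jit, Python's print sees only an abstract tracer,
    # so it can report shape and dtype but not the runtime value.
    print("python print ->", x)
    # Dynamic: jax.debug.print is staged into the computation and
    # prints the actual traced value when the function runs.
    jax.debug.print("jax.debug.print -> {x}", x=x)
    return x * 2

y = f(jnp.float32(3.0))
```

Running this shows a `Traced<...>` object from the plain `print` and the concrete value `3.0` from `jax.debug.print`.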
`docs/tutorials/jaxpr.md` (1 addition, 1 deletion)

````diff
@@ -13,7 +13,7 @@ kernelspec:
 ---
 
 (jax-internals-jaxpr)=
-# JAX internals 301: The jaxpr language
+# JAX internals: The jaxpr language
 
 Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in algebraic normal form (ANF).
````
````diff
 The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output.
 
-Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
+Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
+This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
+If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
 
-Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.
+Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour under transformations. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JIT-compiled function to run once (during the first call), and never again, due to JAX's traced execution model.
 
 When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.
````
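The side-effect behaviour this hunk describes can be reproduced with a small sketch. The body of `log2` here is an assumption reconstructed around the `global_list.append(x)` line quoted in the diff:

```python
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
    global_list.append(x)  # impure: the jaxpr will not record this
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

# The printed jaxpr lists only the log/div primitives; the append is absent.
print(jax.make_jaxpr(log2)(3.0))

# ...but the side-effect did run, exactly once, during tracing,
# and it captured a tracer object rather than a concrete value.
print(global_list)
```

This is exactly the "runs once during tracing" rule of thumb: the append happens while Python executes the function, not as part of the recorded computation.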
````diff
@@ -73,7 +75,8 @@ See how the printed `x` is a `Traced` object? That's the JAX internals at work.
 
 The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation.
 
-A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:
+A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it.
+For example, if we have a Python conditional, the jaxpr will only know about the branch we take:
 
 ```{code-cell}
 def log2_if_rank_2(x):
````
````diff
@@ -143,8 +146,7 @@ def f(x):
   else:
     return 2 * x
 
-f_jit = jax.jit(f)
-f_jit(10)  # Should raise an error.
+jax.jit(f)(10)  # Raises an error
 ```
 
 ```{code-cell}
````
````diff
@@ -158,19 +160,17 @@ def g(x, n):
     i += 1
   return x + i
 
-g_jit = jax.jit(g)
-g_jit(10, 20)  # Should raise an error.
+jax.jit(g)(10, 20)  # Raises an error
 ```
 
-The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it.
+The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.
+Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as `shape` or `dtype`, and not via their values.
+For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
 
-The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes.
-
-For {func}`jax.jit`, the default level is {class}`~jax.core.ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.
-
-In {func}`jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
-
-One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
+One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical.
+In that case, you can consider JIT-compiling only part of the function.
+For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
 
 ```{code-cell}
 # While loop conditioned on x and n with a jitted body.
````
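The partial-jit pattern referenced by `g_inner_jitted` might look like the following. This is a sketch: the `loop_body` split is assumed from the comment line above, not shown in full in the diff.

```python
import jax

@jax.jit
def loop_body(prev_i):
    # Only this (presumably expensive) body is compiled.
    return prev_i + 1

def g_inner_jitted(x, n):
    i = 0
    # The value-dependent while loop stays in ordinary Python,
    # so conditioning on n's value is fine here.
    while i < n:
        i = loop_body(i)
    return x + i

print(g_inner_jitted(10, 20))
```

Because the loop condition runs outside any trace, no `static_argnums` workaround is needed for `n`.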
````diff
@@ -188,7 +188,11 @@ def g_inner_jitted(x, n):
 g_inner_jitted(10, 20)
 ```
 
-If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values.
+## Marking arguments as static
+
+If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`.
+The cost of this is that the resulting jaxpr and compiled artifact depends on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input.
+It is only a good strategy if the function is guaranteed to see a limited set of static values.
 
 ```{code-cell}
 f_jit_correct = jax.jit(f, static_argnums=0)
````
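A runnable sketch of the `static_argnums` fix follows. The positive branch of `f` is an assumption; only the `else: return 2 * x` branch appears in the diff above.

```python
import jax

def f(x):
    if x > 0:        # assumed branch; only the else-branch is shown in the diff
        return x
    else:
        return 2 * x

# With argument 0 marked static, tracing sees the concrete value of x,
# so the Python `if` works -- at the cost of a re-compile per new value.
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10), f_jit_correct(-10))
```

Calling `f_jit_correct` with many distinct values would defeat the purpose, since each triggers a fresh trace and compile.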
````diff
@@ -227,9 +231,10 @@ print("g:")
 %timeit g(10, 20)
 ```
 
-This is because {func}`jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
+This is because {func}`jax.jit` introduces some overhead itself, and so it usually only saves time if the compiled function is nontrivial, or if you will run it numerous times.
+Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
 
-Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
+Generally, you want to JIT-compile the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
````
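One way to see the compile-once, run-many trade-off concretely. The `selu` function here is illustrative and not from this diff, and timings vary by machine; `block_until_ready` avoids measuring only JAX's asynchronous dispatch.

```python
import time
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Illustrative elementwise function, nontrivial enough to benefit from jit.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
x = jnp.arange(1_000_000, dtype=jnp.float32)

# First call pays tracing + compilation cost.
selu_jit(x).block_until_ready()

start = time.perf_counter()
selu_jit(x).block_until_ready()  # subsequent calls reuse the compiled executable
print(f"cached jitted call: {time.perf_counter() - start:.6f} s")
```

Amortized over many calls (as in a training loop's update step), the one-time compilation cost becomes negligible.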