You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/tutorials/quickstart.md
+6-4Lines changed: 6 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -14,7 +14,7 @@ kernelspec:
14
14
15
15
# Quickstart
16
16
17
-
**JAX a library for array-oriented numerical computation (*a la*[NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
17
+
**JAX a library for array-oriented numerical computation (*à la*[NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
18
18
19
19
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
20
20
@@ -125,16 +125,18 @@ In the above example we jitted `sum_logistic` and then took its derivative. We c
Similarly, the {func}`jax.jacobian` transformation can be used to compute gradients of vector-valued functions:
128
+
Beyond scalar-valued functions, the {func}`jax.jacobian` transformation can be
129
+
used to compute the full Jacobian matrix for vector-valued functions:
129
130
130
131
```{code-cell}
131
132
from jax import jacobian
132
133
print(jacobian(jnp.exp)(x_small))
133
134
```
134
135
135
-
For more advanced autodiff operations, you can use {func}`jax..jacrev` for reverse-mode vector-Jacobian products,
136
-
and {func}`jax.jacfwd` for forward-mode Jacobian-vector products.
136
+
For more advanced autodiff operations, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products,
137
+
and {func}`jax.jvp` and {func}`jax.linearize` for forward-mode Jacobian-vector products.
137
138
The two can be composed arbitrarily with one another, and with other JAX transformations.
139
+
For example, {func}`jax.jvp` and {func}`jax.vjp` are used to define the forward-mode {func}`jax.jacfwd` and reverse-mode {func}`jax.jacrev` for computing Jacobians in forward- and reverse-mode, respectively.
138
140
Here's one way to compose them to make a function that efficiently computes full Hessian matrices:
0 commit comments