Skip to content

Commit 8424824

Browse files
author
jax authors
committed
Merge pull request #20836 from mattjj:quickstart-tweaks
PiperOrigin-RevId: 626472400
2 parents 1671617 + 94e3a6e commit 8424824

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

docs/tutorials/quickstart.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ kernelspec:
1414

1515
# Quickstart
1616

17-
**JAX a library for array-oriented numerical computation (*a la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
17+
**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
1818

1919
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
2020

@@ -125,16 +125,18 @@ In the above example we jitted `sum_logistic` and then took its derivative. We c
125125
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
126126
```
127127

128-
Similarly, the {func}`jax.jacobian` transformation can be used to compute gradients of vector-valued functions:
128+
Beyond scalar-valued functions, the {func}`jax.jacobian` transformation can be
129+
used to compute the full Jacobian matrix for vector-valued functions:
129130

130131
```{code-cell}
131132
from jax import jacobian
132133
print(jacobian(jnp.exp)(x_small))
133134
```
134135

135-
For more advanced autodiff operations, you can use {func}`jax..jacrev` for reverse-mode vector-Jacobian products,
136-
and {func}`jax.jacfwd` for forward-mode Jacobian-vector products.
136+
For more advanced autodiff operations, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products,
137+
and {func}`jax.jvp` and {func}`jax.linearize` for forward-mode Jacobian-vector products.
137138
The two can be composed arbitrarily with one another, and with other JAX transformations.
139+
For example, {func}`jax.jvp` and {func}`jax.vjp` are used to define the forward-mode {func}`jax.jacfwd` and reverse-mode {func}`jax.jacrev` for computing Jacobians in forward- and reverse-mode, respectively.
138140
Here's one way to compose them to make a function that efficiently computes full Hessian matrices:
139141

140142
```{code-cell}

0 commit comments

Comments
 (0)