Skip to content

Commit d1235fa

Browse files
committed
DOC: add key concepts doc
This will replace the new content in thinking-in-jax within the new tutorial flow.
1 parent 342887b commit d1235fa

File tree

3 files changed

+208
-5
lines changed

3 files changed

+208
-5
lines changed

docs/glossary.rst

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ JAX Glossary of Terms
33

44
.. glossary::
55

6+
Array
7+
JAX's analog of :class:`numpy.ndarray`. See :class:`jax.Array`.
8+
69
CPU
710
Short for *Central Processing Unit*, CPUs are the standard computational architecture
811
available in most computers. JAX can run computations on CPUs, but often can achieve
@@ -12,9 +15,6 @@ JAX Glossary of Terms
1215
The generic name used to refer to the :term:`CPU`, :term:`GPU`, or :term:`TPU` used
1316
by JAX to perform computations.
1417

15-
DeviceArray
16-
JAX's analog of the :class:`numpy.ndarray`. See :class:`jaxlib.xla_extension.DeviceArray`.
17-
1818
forward-mode autodiff
1919
See :term:`JVP`
2020

@@ -30,7 +30,7 @@ JAX Glossary of Terms
3030
jaxpr
3131
Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that
3232
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
33-
See :ref:`understanding-jaxprs` for more information.
33+
See :ref:`understanding-jaxprs` for more discussion and examples.
3434

3535
JIT
3636
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
@@ -41,11 +41,21 @@ JAX Glossary of Terms
4141
differentiation. For more details, see :ref:`jacobian-vector-product`. In JAX, JVP is
4242
a :term:`transformation` that is implemented via :func:`jax.jvp`. See also :term:`VJP`.
4343

44+
primitive
45+
A primitive is a fundamental unit of computation used in JAX programs. Most functions
46+
in :mod:`jax.lax` represent individual primitives. When representing a computation in
47+
a :term:`jaxpr`, each operation in the jaxpr is a primitive.
48+
4449
pure function
4550
A pure function is a function whose outputs are based only on its inputs, and which has
4651
no side-effects. JAX's :term:`transformation` model is designed to work with pure functions.
4752
See also :term:`functional programming`.
4853

54+
pytree
55+
A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more
56+
general containers of array values in a uniform way. Refer to {ref}`working-with-pytrees`
57+
for a more detailed discussion.
58+
4959
reverse-mode autodiff
5060
See :term:`VJP`.
5161

@@ -65,7 +75,7 @@ JAX Glossary of Terms
6575
fast operations on arrays (see also :term:`CPU` and :term:`GPU`).
6676

6777
Tracer
68-
An object used as a standin for a JAX :term:`DeviceArray` in order to determine the
78+
An object used as a standin for a JAX :term:`Array` in order to determine the
6979
sequence of operations performed by a Python function. Internally, JAX implements this
7080
via the :class:`jax.core.Tracer` class.
7181

docs/tutorials/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ JAX 101
1818
:maxdepth: 1
1919

2020
quickstart
21+
key-concepts
2122
thinking-in-jax
2223
jit-compilation
2324
automatic-vectorization

docs/tutorials/key-concepts.md

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
---
2+
jupytext:
3+
formats: md:myst
4+
text_representation:
5+
extension: .md
6+
format_name: myst
7+
format_version: 0.13
8+
jupytext_version: 1.16.0
9+
kernelspec:
10+
display_name: Python 3
11+
language: python
12+
name: python3
13+
---
14+
15+
(key-concepts)=
16+
# Key Concepts
17+
18+
This section briefly introduces some key concepts of the JAX package.
19+
20+
(key-concepts-jax-arrays)=
21+
## JAX arrays ({class}`jax.Array`)
22+
23+
- `jax.Array` is the default array implementation in JAX.
24+
- `jax.Array` objects are never created directly, but rather using familiar
25+
array creation APIs.
26+
- JAX arrays may be stored on a single device, or sharded across many devices.
27+
28+
### Array creation
29+
30+
JAX arrays are never constructed directly, but rather are constructed via JAX API functions.
31+
For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality
32+
such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
33+
34+
```{code-cell}
35+
import jax
36+
import jax.numpy as jnp
37+
38+
x = jnp.arange(5)
39+
isinstance(x, jax.Array)
40+
```
41+
42+
If you use Python type annotations in your code, {class}`jax.Array` is the appropriate
43+
annotation for jax array objects (see {mod}`jax.typing` for more discussion).
44+
45+
### Array devices and sharding
46+
47+
JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:
48+
49+
```{code-cell}
50+
x.devices()
51+
```
52+
53+
In general, an array may be *sharded* across multiple devices, in a manner that can be inspected via the `sharding` attribute:
54+
55+
```{code-cell}
56+
x.sharding
57+
```
58+
59+
Here the array is on a single device, but in general a JAX array can be
60+
sharded across multiple devices, or even multiple hosts.
61+
To read more about sharded arrays and parallel computation, refer to {ref}`single-host-sharding`
62+
63+
(key-concepts-transformations)=
64+
## Transformations
65+
Along with functions to operate on arrays, JAX includes a number of
66+
{term}`transformations <transformation>` which operate on JAX functions. These include
67+
68+
- {func}`jax.jit`: Just-in-time (JIT) compilation; see {ref}`jit-compilation`
69+
- {func}`jax.vmap`: Vectorizing transform; see {ref}`automatic-vectorization`
70+
- {func}`jax.grad`: Gradient transform; see {ref}`automatic-differentiation`
71+
72+
as well as several others. Transformations accept a function as an argument, and return a
73+
new transformed function. For example, here's how you might JIT-compile a simple SELU function:
74+
75+
```{code-cell}
76+
def selu(x, alpha=1.67, lambda_=1.05):
77+
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
78+
79+
selu_jit = jax.jit(selu)
80+
print(selu_jit(1.0))
81+
```
82+
83+
Often you'll see transformations applied using Python's decorator syntax for convenience:
84+
85+
```{code-cell}
86+
@jax.jit
87+
def selu(x, alpha=1.67, lambda_=1.05):
88+
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
89+
```
90+
91+
Transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, and others are
92+
key to using JAX effectively, and we'll cover them in detail in later sections.
93+
94+
(key-concepts-tracing)=
95+
## Tracing
96+
97+
The magic behind transformations is the notion of a {term}`Tracers <Tracer>`.
98+
Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order
99+
to extract the sequence of operations that the function encodes.
100+
101+
You can see this by printing any array value within transformed JAX code; for example:
102+
103+
```{code-cell}
104+
@jax.jit
105+
def f(x):
106+
print(x)
107+
return x + 1
108+
109+
x = jnp.arange(5)
110+
result = f(x)
111+
```
112+
113+
The value printed is not the array `x`, but a {class}`~jax.core.Tracer` instance that
114+
represents essential attributes of `x`, such as its `shape` and `dtype`. By executing
115+
the function with traced values, JAX can determine the sequence of operations encoded
116+
by the function before those operations are actually executed: transformations like
117+
{func}`~jax.jit`, {func}`~jax.vmap`, and {func}`~jax.grad` can then map this sequence
118+
of input operations to a transformed sequence of operations.
119+
120+
(key-concepts-jaxprs)=
121+
## Jaxprs
122+
123+
JAX has its own intermediate representation for sequences of operations, and these are
124+
known as {term}`jaxprs <jaxpr>`. A jaxpr (short for *JAX eXPRession*) represents a list
125+
of core units of computation called {term}`primitives <primitive>` that represent the
126+
effect of a computation.
127+
128+
For example, consider the `selu` function we defined above:
129+
130+
```{code-cell}
131+
def selu(x, alpha=1.67, lambda_=1.05):
132+
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
133+
```
134+
135+
We can use the {func}`jax.make_jaxpr` utility to convert this function into a jaxpr
136+
given a particular input:
137+
138+
```{code-cell}
139+
x = jnp.arange(5.0)
140+
jax.make_jaxpr(selu)(x)
141+
```
142+
143+
Comparing this to the Python function definition, we see that it encodes the precise
144+
sequence of operations that the function represents. We'll go into more depth about
145+
jaxprs later in {ref}`jax-internals-jaxpr`.
146+
147+
(key-concepts-pytrees)=
148+
## Pytrees
149+
150+
JAX functions and transformations fundamentally operate on arrays, but in practice it is
151+
convenient to write code that work with collections of arrays: for example, a neural
152+
network might organize its parameters in a dictionary of arrays with meaningful keys.
153+
Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree`
154+
abstraction to treat such collections in a uniform matter.
155+
156+
Here are some examples of objects that can be treated as pytrees:
157+
158+
```{code-cell}
159+
# (nested) list of parameters
160+
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
161+
162+
print(jax.tree.structure(params))
163+
print(jax.tree.leaves(params))
164+
```
165+
166+
```{code-cell}
167+
# Dictionary of parameters
168+
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
169+
170+
print(jax.tree.structure(params))
171+
print(jax.tree.leaves(params))
172+
```
173+
174+
```{code-cell}
175+
# Named tuple of parameters
176+
from typing import NamedTuple
177+
178+
class Params(NamedTuple):
179+
a: int
180+
b: float
181+
182+
params = Params(1, 5.0)
183+
print(jax.tree.structure(params))
184+
print(jax.tree.leaves(params))
185+
```
186+
187+
JAX has a number of general-purpose utilities for working with PyTrees; for example
188+
the functions {func}`jax.tree.map` can be used to map a function to every leaf in a
189+
tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the leaves
190+
in a tree.
191+
192+
You can learn more in the {ref}`working-with-pytrees` tutorial.

0 commit comments

Comments
 (0)