| 1 | +--- |
| 2 | +jupytext: |
| 3 | + formats: md:myst |
| 4 | + text_representation: |
| 5 | + extension: .md |
| 6 | + format_name: myst |
| 7 | + format_version: 0.13 |
| 8 | + jupytext_version: 1.16.0 |
| 9 | +kernelspec: |
| 10 | + display_name: Python 3 |
| 11 | + language: python |
| 12 | + name: python3 |
| 13 | +--- |
| 14 | + |
| 15 | +(key-concepts)= |
| 16 | +# Key Concepts |
| 17 | + |
| 18 | +This section briefly introduces some key concepts of the JAX package. |
| 19 | + |
| 20 | +(key-concepts-jax-arrays)= |
| 21 | +## JAX arrays ({class}`jax.Array`) |
| 22 | + |
| 23 | +- `jax.Array` is the default array implementation in JAX. |
| 24 | +- `jax.Array` objects are never created directly, but rather using familiar |
| 25 | + array creation APIs. |
| 26 | +- JAX arrays may be stored on a single device, or sharded across many devices. |
| 27 | + |
| 28 | +### Array creation |
| 29 | + |
| 30 | +JAX arrays are never constructed directly, but rather are constructed via JAX API functions. |
| 31 | +For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality |
| 32 | +such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc. |
| 33 | + |
| 34 | +```{code-cell} |
| 35 | +import jax |
| 36 | +import jax.numpy as jnp |
| 37 | + |
| 38 | +x = jnp.arange(5) |
| 39 | +isinstance(x, jax.Array) |
| 40 | +``` |
| 41 | + |
| 42 | +If you use Python type annotations in your code, {class}`jax.Array` is the appropriate |
| 43 | +annotation for JAX array objects (see {mod}`jax.typing` for more discussion). |
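
As a minimal sketch of such annotations, the hypothetical function `normalize` below (not part of JAX) accepts anything array-like and declares a {class}`jax.Array` return type. It assumes `jax.typing.ArrayLike` is the right annotation for inputs that are converted to arrays inside the function.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def normalize(x: ArrayLike) -> jax.Array:
    """Hypothetical helper: scale a vector to unit Euclidean norm."""
    x = jnp.asarray(x)  # lists, NumPy arrays, and scalars all become jax.Array
    return x / jnp.linalg.norm(x)


unit = normalize([3.0, 4.0])
print(isinstance(unit, jax.Array))
```

The annotations are not enforced at runtime; they serve static type checkers and readers.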
| 44 | + |
| 45 | +### Array devices and sharding |
| 46 | + |
| 47 | +JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device: |
| 48 | + |
| 49 | +```{code-cell} |
| 50 | +x.devices() |
| 51 | +``` |
| 52 | + |
| 53 | +In general, an array may be *sharded* across multiple devices, in a manner that can be inspected via the `sharding` attribute: |
| 54 | + |
| 55 | +```{code-cell} |
| 56 | +x.sharding |
| 57 | +``` |
| 58 | + |
| 59 | +Here the array is on a single device, but in general a JAX array can be |
| 60 | +sharded across multiple devices, or even multiple hosts. |
| 61 | +To read more about sharded arrays and parallel computation, refer to {ref}`single-host-sharding`. |
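
The following is a rough sketch of how an array might be placed with an explicit sharding. It assumes a one-dimensional device mesh built from whatever `jax.devices()` returns; the axis name `"data"` and the variable names are illustrative choices, not taken from the text above. On a machine with a single device the array simply ends up on that one device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1D mesh over all available devices (possibly just one CPU).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))  # "data" is an arbitrary axis name

# Split the leading array axis across the "data" mesh axis.
sharding = NamedSharding(mesh, PartitionSpec("data"))

# The array length is a multiple of the device count so it divides evenly.
x = jnp.arange(8 * devices.size)
x_sharded = jax.device_put(x, sharding)

print(x_sharded.devices())
print(x_sharded.sharding)
```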
| 62 | + |
| 63 | +(key-concepts-transformations)= |
| 64 | +## Transformations |
| 65 | +Along with functions to operate on arrays, JAX includes a number of |
| 66 | +{term}`transformations <transformation>` which operate on JAX functions. These include |
| 67 | + |
| 68 | +- {func}`jax.jit`: Just-in-time (JIT) compilation; see {ref}`jit-compilation` |
| 69 | +- {func}`jax.vmap`: Vectorizing transform; see {ref}`automatic-vectorization` |
| 70 | +- {func}`jax.grad`: Gradient transform; see {ref}`automatic-differentiation` |
| 71 | + |
| 72 | +as well as several others. Transformations accept a function as an argument, and return a |
| 73 | +new transformed function. For example, here's how you might JIT-compile a simple SELU function: |
| 74 | + |
| 75 | +```{code-cell} |
| 76 | +def selu(x, alpha=1.67, lambda_=1.05): |
| 77 | + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) |
| 78 | + |
| 79 | +selu_jit = jax.jit(selu) |
| 80 | +print(selu_jit(1.0)) |
| 81 | +``` |
| 82 | + |
| 83 | +Often you'll see transformations applied using Python's decorator syntax for convenience: |
| 84 | + |
| 85 | +```{code-cell} |
| 86 | +@jax.jit |
| 87 | +def selu(x, alpha=1.67, lambda_=1.05): |
| 88 | + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) |
| 89 | +``` |
| 90 | + |
| 91 | +Transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, and others are |
| 92 | +key to using JAX effectively, and we'll cover them in detail in later sections. |
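
As a small illustration of the other two transformations listed above, the sketch below applies {func}`jax.grad` and {func}`jax.vmap` to a hypothetical `sum_of_squares` function, then composes all three transformations. The function and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# grad returns a new function computing d(sum_of_squares)/dx, which is 2 * x.
gradient = jax.grad(sum_of_squares)(x)

# vmap returns a new function that maps over the leading (batch) axis.
batch = jnp.stack([x, 2 * x])
batched_values = jax.vmap(sum_of_squares)(batch)

# Transformations compose: a compiled, per-example gradient.
per_example_grads = jax.jit(jax.vmap(jax.grad(sum_of_squares)))(batch)

print(gradient)
print(batched_values)
print(per_example_grads)
```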
| 93 | + |
| 94 | +(key-concepts-tracing)= |
| 95 | +## Tracing |
| 96 | + |
| 97 | +The magic behind transformations is the notion of a {term}`Tracer`. |
| 98 | +Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order |
| 99 | +to extract the sequence of operations that the function encodes. |
| 100 | + |
| 101 | +You can see this by printing any array value within transformed JAX code; for example: |
| 102 | + |
| 103 | +```{code-cell} |
| 104 | +@jax.jit |
| 105 | +def f(x): |
| 106 | + print(x) |
| 107 | + return x + 1 |
| 108 | + |
| 109 | +x = jnp.arange(5) |
| 110 | +result = f(x) |
| 111 | +``` |
| 112 | + |
| 113 | +The value printed is not the array `x`, but a {class}`~jax.core.Tracer` instance that |
| 114 | +represents essential attributes of `x`, such as its `shape` and `dtype`. By executing |
| 115 | +the function with traced values, JAX can determine the sequence of operations encoded |
| 116 | +by the function before those operations are actually executed: transformations like |
| 117 | +{func}`~jax.jit`, {func}`~jax.vmap`, and {func}`~jax.grad` can then map this sequence |
| 118 | +of input operations to a transformed sequence of operations. |
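
One consequence worth seeing in code is that Python side effects inside a jitted function run only while it is being traced. The sketch below records each trace in a hypothetical `trace_log` list; it assumes the usual {func}`jax.jit` caching behavior, in which a call with a previously seen shape and dtype reuses the compiled result and a new shape triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: appended to only while tracing


@jax.jit
def add_one(x):
    # This line runs at trace time, with x as a Tracer, not on every call.
    trace_log.append(x.shape)
    return x + 1


add_one(jnp.arange(5))  # first call: traced and compiled
add_one(jnp.arange(5))  # same shape and dtype: compiled version is reused
add_one(jnp.arange(6))  # new shape: traced again

print(trace_log)
```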
| 119 | + |
| 120 | +(key-concepts-jaxprs)= |
| 121 | +## Jaxprs |
| 122 | + |
| 123 | +JAX has its own intermediate representation for sequences of operations, and these are |
| 124 | +known as {term}`jaxprs <jaxpr>`. A jaxpr (short for *JAX eXPRession*) represents a list |
| 125 | +of core units of computation called {term}`primitives <primitive>` that represent the |
| 126 | +effect of a computation. |
| 127 | + |
| 128 | +For example, consider the `selu` function we defined above: |
| 129 | + |
| 130 | +```{code-cell} |
| 131 | +def selu(x, alpha=1.67, lambda_=1.05): |
| 132 | + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) |
| 133 | +``` |
| 134 | + |
| 135 | +We can use the {func}`jax.make_jaxpr` utility to convert this function into a jaxpr |
| 136 | +given a particular input: |
| 137 | + |
| 138 | +```{code-cell} |
| 139 | +x = jnp.arange(5.0) |
| 140 | +jax.make_jaxpr(selu)(x) |
| 141 | +``` |
| 142 | + |
| 143 | +Comparing this to the Python function definition, we see that it encodes the precise |
| 144 | +sequence of operations that the function represents. We'll go into more depth about |
| 145 | +jaxprs later in {ref}`jax-internals-jaxpr`. |
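
A jaxpr can also be inspected programmatically. The sketch below is one way to list the primitives it contains, assuming the object returned by {func}`jax.make_jaxpr` exposes its equations as `.jaxpr.eqns`, each carrying a `primitive` with a `name`. The function `scaled_sine` is a hypothetical example written directly with `jax.lax` operations so that each call corresponds to exactly one primitive.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scaled_sine(x):
    # Two lax operations, hence two primitives in the jaxpr.
    return lax.mul(lax.sin(x), x)


closed_jaxpr = jax.make_jaxpr(scaled_sine)(jnp.arange(3.0))

# Each equation in the jaxpr applies one primitive to its inputs.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```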
| 146 | + |
| 147 | +(key-concepts-pytrees)= |
| 148 | +## Pytrees |
| 149 | + |
| 150 | +JAX functions and transformations fundamentally operate on arrays, but in practice it is |
| 151 | +convenient to write code that works with collections of arrays: for example, a neural |
| 152 | +network might organize its parameters in a dictionary of arrays with meaningful keys. |
| 153 | +Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree` |
| 154 | +abstraction to treat such collections in a uniform manner. |
| 155 | + |
| 156 | +Here are some examples of objects that can be treated as pytrees: |
| 157 | + |
| 158 | +```{code-cell} |
| 159 | +# (nested) list of parameters |
| 160 | +params = [1, 2, (jnp.arange(3), jnp.ones(2))] |
| 161 | + |
| 162 | +print(jax.tree.structure(params)) |
| 163 | +print(jax.tree.leaves(params)) |
| 164 | +``` |
| 165 | + |
| 166 | +```{code-cell} |
| 167 | +# Dictionary of parameters |
| 168 | +params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)} |
| 169 | + |
| 170 | +print(jax.tree.structure(params)) |
| 171 | +print(jax.tree.leaves(params)) |
| 172 | +``` |
| 173 | + |
| 174 | +```{code-cell} |
| 175 | +# Named tuple of parameters |
| 176 | +from typing import NamedTuple |
| 177 | + |
| 178 | +class Params(NamedTuple): |
| 179 | + a: int |
| 180 | + b: float |
| 181 | + |
| 182 | +params = Params(1, 5.0) |
| 183 | +print(jax.tree.structure(params)) |
| 184 | +print(jax.tree.leaves(params)) |
| 185 | +``` |
| 186 | + |
| 187 | +JAX has a number of general-purpose utilities for working with pytrees; for example, |
| 188 | +the function {func}`jax.tree.map` can be used to map a function over every leaf in a |
| 189 | +tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the leaves |
| 190 | +in a tree. |
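
As a brief sketch of these two utilities, the example below doubles every leaf of a small parameter dictionary and then counts the total number of array elements. The dictionary contents and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

params = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Apply a function to every leaf; the result has the same tree structure.
doubled = jax.tree.map(lambda leaf: 2 * leaf, params)

# Fold over the leaves, starting from 0, to count all array elements.
num_elements = jax.tree.reduce(lambda acc, leaf: acc + leaf.size, params, 0)

print(doubled)
print(num_elements)
```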
| 191 | + |
| 192 | +You can learn more in the {ref}`working-with-pytrees` tutorial. |