
Commit c2b1633

Document all checkpoint policies in one place, on the JAX public API page.
1 parent ff0a516 commit c2b1633

5 files changed, +38 -45 lines changed


docs/gradient-checkpointing.md

Lines changed: 1 addition & 16 deletions
@@ -411,22 +411,7 @@ The code defines a function `f` that which applies checkpointing with a custom p

 #### List of policies

-The policies are:
-* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)
-* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)
-* `dots_saveable` or its alias `checkpoint_dots`
-* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`
-* `save_anything_but_these_names` (save any values except for the output of
-  `checkpoint_name` with any of the names given)
-* `save_any_names_but_these` (save only named values, i.e. any outputs of
-  `checkpoint_name`, except for those with the names given)
-* `save_only_these_names` (save only named values, and only among the names
-  given)
-* `offload_dot_with_no_batch_dims` same as `dots_with_no_batch_dims_saveable`,
-  but offload to CPU memory instead of recomputing.
-* `save_and_offload_only_these_names` same as `save_only_these_names`, but
-  offload to CPU memory instead of recomputing.
-* `save_from_both_policies(policy_1, policy_2)` (like a logical `or`, so that a residual is saveable if it is saveable according to `policy_1` _or_ `policy_2`)
+The policies can be found [here](https://docs.jax.dev/en/latest/jax.html#checkpoint-policies).

 Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass.
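As an illustration of the context sentence above (this is not part of the commit), here is a minimal sketch of passing a policy to `jax.checkpoint`. The toy function `f` and its inputs are made up for the example. The policy only marks the matmul output as saveable, and the gradient is unaffected either way.

```python
import jax
import jax.numpy as jnp

def f(w, x):
  # Hypothetical toy function: one matmul (no batch dims), then elementwise ops.
  h = jnp.sin(jnp.dot(x, w))
  return jnp.sum(h ** 2)

# The policy marks matmul outputs as *saveable*; whether a value is actually
# stored still depends on what the backward pass needs.
f_remat = jax.checkpoint(
    f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

w = jnp.ones((3, 2))
x = jnp.arange(12.0).reshape(4, 3)

grad_plain = jax.grad(f)(w, x)
grad_remat = jax.grad(f_remat)(w, x)

# Rematerialization trades memory for compute; it should not change results.
same = bool(jnp.allclose(grad_plain, grad_remat))
```

To see which residuals end up saved, `jax.ad_checkpoint.print_saved_residuals(f_remat, w, x)` can be called on the wrapped function.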

docs/jax.rst

Lines changed: 18 additions & 0 deletions
@@ -263,3 +263,21 @@ Miscellaneous
     print_environment_info
     live_arrays
     clear_caches
+
+Checkpoint policies
+-------------------
+
+.. autosummary::
+  :toctree: _autosummary
+
+    checkpoint_policies.everything_saveable
+    checkpoint_policies.nothing_saveable
+    checkpoint_policies.dots_saveable
+    checkpoint_policies.checkpoint_dots
+    checkpoint_policies.dots_with_no_batch_dims_saveable
+    checkpoint_policies.checkpoint_dots_with_no_batch_dims
+    checkpoint_policies.save_any_names_but_these
+    checkpoint_policies.save_only_these_names
+    checkpoint_policies.offload_dot_with_no_batch_dims
+    checkpoint_policies.save_and_offload_only_these_names
+    checkpoint_policies.save_from_both_policies
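For orientation, the entries listed above are reached through the `jax.checkpoint_policies` namespace. The short sketch below is not part of the commit. It assumes the aliases are bound to the same function objects, as `checkpoint_dots = dots_saveable` in `jax/_src/ad_checkpoint.py` suggests, and that the two constant policies return plain booleans, as the diff for that file shows.

```python
import jax

policies = jax.checkpoint_policies

# The two constant policies ignore their arguments entirely.
always = policies.everything_saveable()
never = policies.nothing_saveable()

# Aliases are expected to be the same objects as the primary names
# (an assumption based on the source shown in this commit).
dots_alias_ok = policies.checkpoint_dots is policies.dots_saveable
no_batch_alias_ok = (
    policies.checkpoint_dots_with_no_batch_dims
    is policies.dots_with_no_batch_dims_saveable)
```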

docs/notebooks/autodiff_remat.ipynb

Lines changed: 1 addition & 11 deletions
@@ -812,17 +812,7 @@
    "source": [
     "Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.\n",
     "\n",
-    "Some of the policies are:\n",
-    "* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)\n",
-    "* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)\n",
-    "* `dots_saveable` or its alias `checkpoint_dots`\n",
-    "* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`\n",
-    "* `save_anything_but_these_names` (save any values except for the output of\n",
-    "  `checkpoint_name` with any of the names given)\n",
-    "* `save_any_names_but_these` (save only named values, i.e. any outputs of\n",
-    "  `checkpoint_name`, except for those with the names given)\n",
-    "* `save_only_these_names` (save only named values, and only among the names\n",
-    "  given)\n",
+    "A list of policies can be found [here](https://docs.jax.dev/en/latest/jax.html#checkpoint-policies).\n",
     "\n",
     "Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass."
    ]

docs/notebooks/autodiff_remat.md

Lines changed: 0 additions & 12 deletions
@@ -396,18 +396,6 @@ print_saved_residuals(loss_checkpoint2, params, x, y)

 Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.

-Some of the policies are:
-* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)
-* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)
-* `dots_saveable` or its alias `checkpoint_dots`
-* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`
-* `save_anything_but_these_names` (save any values except for the output of
-  `checkpoint_name` with any of the names given)
-* `save_any_names_but_these` (save only named values, i.e. any outputs of
-  `checkpoint_name`, except for those with the names given)
-* `save_only_these_names` (save only named values, and only among the names
-  given)
-
 Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass.

 +++ {"id": "lixGsLNwxQo7"}

jax/_src/ad_checkpoint.py

Lines changed: 18 additions & 6 deletions
@@ -56,30 +56,37 @@
 ### Policies

 def everything_saveable(*_, **__) -> bool:
-  # This is the effective policy without any use of jax.remat.
+  """The default strategy, as if ``jax.checkpoint`` were not being used at all.
+
+  This is the effective policy without any use of jax.remat."""
   return True

 def nothing_saveable(*_, **__) -> bool:
-  # This is the effective policy when using jax.remat without explicit policy.
+  """Rematerialize everything, as if a custom policy were not being used at all.
+
+  This is the effective policy when using jax.remat without explicit policy."""
   return False

 def dots_saveable(prim, *_, **__) -> bool:
-  # Matrix multiplies are expensive, so let's save them (and nothing else).
+  """Matrix multiplies are expensive, so let's save them (and nothing else)."""
   return prim in {lax_internal.dot_general_p,
                   lax_convolution.conv_general_dilated_p}
 checkpoint_dots = dots_saveable

 def dots_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
-  # This is a useful heuristic for transformers.
+  """This is a useful heuristic for transformers."""
   if prim is lax_internal.dot_general_p:
     (_, _), (lhs_b, rhs_b) = params['dimension_numbers']
     if not lhs_b and not rhs_b:
       return True
   return False

 def offload_dot_with_no_batch_dims(offload_src, offload_dst):
+  """Same as ``dots_with_no_batch_dims_saveable``, but offload to CPU memory
+  instead of recomputing.
+
+  This is a useful heuristic for transformers."""
   def policy(prim, *_, **params):
-    # This is a useful heuristic for transformers.
     if prim is lax_internal.dot_general_p:
       (_, _), (lhs_b, rhs_b) = params['dimension_numbers']
       if not lhs_b and not rhs_b:
@@ -100,7 +107,8 @@ def policy(prim, *_, **params):
   return policy

 def save_any_names_but_these(*names_not_to_save):
-  """Save only named values, excluding the names given."""
+  """Save only named values, i.e. any outputs of `checkpoint_name`, excluding
+  the names given."""
   names_not_to_save = frozenset(names_not_to_save)
   def policy(prim, *_, **params):
     if prim is name_p:
@@ -120,6 +128,8 @@ def policy(prim, *_, **params):
 def save_and_offload_only_these_names(
     *, names_which_can_be_saved, names_which_can_be_offloaded,
     offload_src, offload_dst):
+  """Same as ``save_only_these_names``, but offload to CPU memory instead of
+  recomputing."""
   names_which_can_be_saved = set(names_which_can_be_saved)
   names_which_can_be_offloaded = set(names_which_can_be_offloaded)
   intersection = names_which_can_be_saved.intersection(names_which_can_be_offloaded)
@@ -140,7 +150,9 @@ def policy(prim, *_, **params):


 def save_from_both_policies(policy_1, policy_2):
+  """Logical OR of the given policies.

+  A residual is saveable iff it is saveable according to either policy."""
   def policy(prim, *args, **params):
     out1 = policy_1(prim, *args, **params)
     out2 = policy_2(prim, *args, **params)
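The new `save_from_both_policies` docstring describes a logical OR of two policies. The sketch below is not part of the commit and shows one way that could be used. It combines a primitive-based policy with a name-based one. The function `loss`, the name `"hidden"`, and all shapes are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def loss(w1, w2, x):
  # Tag an intermediate so that name-based policies can refer to it.
  h = checkpoint_name(jnp.tanh(jnp.dot(x, w1)), "hidden")
  return jnp.sum(jnp.dot(h, w2) ** 2)

# A residual is saveable if it is a dot without batch dims
# OR if it carries the name "hidden".
policy = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
    jax.checkpoint_policies.save_only_these_names("hidden"),
)
loss_remat = jax.checkpoint(loss, policy=policy)

w1 = jnp.full((3, 4), 0.1)
w2 = jnp.full((4, 2), 0.2)
x = jnp.ones((5, 3))

grads_plain = jax.grad(loss, argnums=(0, 1))(w1, w2, x)
grads_remat = jax.grad(loss_remat, argnums=(0, 1))(w1, w2, x)

# The policy affects what is stored for the backward pass, not the gradients.
grads_match = all(
    bool(jnp.allclose(a, b)) for a, b in zip(grads_plain, grads_remat))
```

Combining policies this way keeps each one simple: the primitive-based policy needs no annotations in the model code, while the name-based one covers specific values tagged with `checkpoint_name`.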
