Skip to content

Commit afb775c

Browse files
author
jax authors
committed
Jax persistent compilation cache user guide.
This user guide covers using the cache on local filesystems and Google Cloud. PiperOrigin-RevId: 623236335
1 parent 828e60c commit afb775c

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

docs/persistent_compilation_cache.md

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Persistent Compilation Cache
2+
3+
JAX has an optional disk cache for compiled programs. If enabled, JAX will
4+
store copies of compiled programs on disk, which can save recompilation time
5+
when running the same or similar tasks repeatedly.
6+
7+
## Usage
8+
9+
The compilation cache is enabled when the
10+
[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
11+
is set. This should be done prior to the first compilation. Set the location as
12+
follows:
13+
14+
```
15+
import jax
16+
17+
# Make sure this is called before jax runs any operations!
18+
jax.config.update("jax_compilation_cache_dir", "cache-location")
19+
```
20+
21+
See the sections below for more detail on `cache-location`.
22+
23+
[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
24+
is an alternate way of setting `cache-location`.
25+
26+
### Local filesystem
27+
28+
`cache-location` can be a directory on the local filesystem. For example:
29+
30+
```
31+
import jax
32+
33+
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
34+
```
35+
36+
Note: the cache does not have an eviction mechanism implemented. If the
37+
cache-location is a directory in the local filesystem, its size will continue
38+
to grow unless files are manually deleted.
39+
40+
### Google Cloud
41+
42+
When running on Google Cloud, the compilation cache can be placed on a Google
43+
Cloud Storage (GCS) bucket. We recommend the following configuration:
44+
45+
* Create the bucket in the same region as where the workload will run.
46+
47+
* Create the bucket in the same project as the workload’s VM(s). Ensure that
48+
permissions are set so that the VM(s) can write to the bucket.
49+
50+
* There is no need for replication for smaller workloads. Larger workloads
51+
could benefit from replication.
52+
53+
* Use “Standard” for the default storage class for the bucket.
54+
55+
* Set the soft delete policy to its shortest: 7 days.
56+
57+
* Set the object lifecycle to the expected duration of the workload run.
58+
For example, if the workload is expected to run for 10 days, set the object
59+
lifecycle to 10 days. That should cover restarts that occur during the entire
60+
run. Use `age` for the lifecycle condition and `Delete` for the action. See
61+
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
62+
for details. If the object lifecycle is not set, the cache will continue to
63+
grow since there is no eviction mechanism implemented.
64+
65+
* All encryption policies are supported.
66+
67+
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
68+
follows:
69+
70+
```
71+
import jax
72+
73+
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
74+
```

docs/user_guides.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ or deployed codebases.
1515
device_memory_profiling
1616
debugging/index
1717
gpu_performance_tips
18-
18+
persistent_compilation_cache
1919

2020
.. toctree::
2121
:maxdepth: 1

0 commit comments

Comments
 (0)