Skip to content

Commit 88dd29a

Browse files
author
jax authors
committed
Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer be a problem with incompatible CPUs accessing the same cache entry. PiperOrigin-RevId: 621341638
1 parent 1baed9b commit 88dd29a

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

jax/_src/compiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from jax._src import traceback_util
3434
from jax._src.interpreters import mlir
3535
from jax._src.lib import xla_client as xc
36+
from jax._src.lib import xla_extension_version
3637
from jax._src.lib.mlir import ir
3738
from jax._src.xla_bridge import process_count
3839
import numpy as np
@@ -253,8 +254,8 @@ def compile_or_get_cached(
253254
# that supports serialization of executables.
254255
# TODO(skye): add warning when initializing cache on unsupported default platform
255256
supported_platforms = ["tpu", "gpu"]
256-
# TODO(b/323256224): Add back support for CPU together with extra fields in a
257-
# cache key with underlying hardware features (xla_extension_version >= 230).
257+
if xla_extension_version >= 253:
258+
supported_platforms.append("cpu")
258259
use_compilation_cache = (
259260
config.enable_compilation_cache.value
260261
and getattr(backend, "supports_executable_serialization", True)

tests/compilation_cache_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import Counter
1516
from functools import partial
1617
import math
1718
import os
1819
import platform
1920
import tempfile
20-
from collections import Counter
2121
import unittest
2222
from unittest import mock
2323
from unittest import SkipTest
@@ -33,11 +33,12 @@
3333
from jax._src import compiler
3434
from jax._src import config
3535
from jax._src import distributed
36-
from jax._src.maps import xmap
3736
from jax._src import monitoring
3837
from jax._src import test_util as jtu
3938
from jax._src import xla_bridge
4039
from jax._src.lib import xla_client
40+
from jax._src.lib import xla_extension_version
41+
from jax._src.maps import xmap
4142
from jax.experimental.pjit import pjit
4243
from jax.sharding import PartitionSpec as P
4344
import numpy as np
@@ -71,9 +72,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
7172

7273
def setUp(self):
7374
super().setUp()
74-
# TODO(b/323256224): Add back support for CPU together with extra fields in
75-
# a cache key with underlying hardware features.
7675
supported_platforms = ["tpu", "gpu"]
76+
if xla_extension_version >= 253:
77+
supported_platforms.append("cpu")
7778

7879
if not jtu.test_device_matches(supported_platforms):
7980
raise SkipTest(

0 commit comments

Comments
 (0)