|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from collections import Counter |
15 | 16 | from functools import partial
|
16 | 17 | import math
|
17 | 18 | import os
|
18 | 19 | import platform
|
19 | 20 | import tempfile
|
20 |
| -from collections import Counter |
21 | 21 | import unittest
|
22 | 22 | from unittest import mock
|
23 | 23 | from unittest import SkipTest
|
|
33 | 33 | from jax._src import compiler
|
34 | 34 | from jax._src import config
|
35 | 35 | from jax._src import distributed
|
36 |
| -from jax._src.maps import xmap |
37 | 36 | from jax._src import monitoring
|
38 | 37 | from jax._src import test_util as jtu
|
39 | 38 | from jax._src import xla_bridge
|
40 | 39 | from jax._src.lib import xla_client
|
| 40 | +from jax._src.lib import xla_extension_version |
| 41 | +from jax._src.maps import xmap |
41 | 42 | from jax.experimental.pjit import pjit
|
42 | 43 | from jax.sharding import PartitionSpec as P
|
43 | 44 | import numpy as np
|
@@ -71,9 +72,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
71 | 72 |
|
72 | 73 | def setUp(self):
|
73 | 74 | super().setUp()
|
74 |
| - # TODO(b/323256224): Add back support for CPU together with extra fields in |
75 |
| - # a cache key with underlying hardware features. |
76 | 75 | supported_platforms = ["tpu", "gpu"]
|
| 76 | + if xla_extension_version >= 253: |
| 77 | + supported_platforms.append("cpu") |
77 | 78 |
|
78 | 79 | if not jtu.test_device_matches(supported_platforms):
|
79 | 80 | raise SkipTest(
|
|
0 commit comments