Skip to content

Commit c79b569

Browse files
authored
Add JAX to the CPU/TPU image. (#992)
Fixes #918 Before, we were installing JAX only on the GPU image. Included tests to prevent regression. http://b/177334844
1 parent 752a79c commit c79b569

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,8 @@ RUN pip install flashtext && \
420420
# pycrypto is used by competitions team.
421421
pip install pycrypto && \
422422
pip install easyocr && \
423+
# Keep JAX version in sync with GPU image.
424+
pip install jax==0.2.12 jaxlib==0.1.64 && \
423425
/tmp/clean-layer.sh
424426

425427
# Download base easyocr models.

gpu.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ RUN pip uninstall -y lightgbm && \
7777
echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
7878
/tmp/clean-layer.sh
7979

80-
# Install JAX
80+
# Install JAX (Keep JAX version in sync with CPU image)
8181
RUN pip install jax==0.2.12 jaxlib==0.1.64+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
8282
/tmp/clean-layer.sh
8383

tests/test_jax.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
import unittest
2-
32
import time
43

4+
import jax.numpy as np
5+
56
from common import gpu_test
7+
from jax import grad, jit
68

79

810
class TestJAX(unittest.TestCase):
911
def tanh(self, x):
10-
import jax.numpy as np
1112
y = np.exp(-2.0 * x)
1213
return (1.0 - y) / (1.0 + y)
1314

14-
@gpu_test
15-
def test_JAX(self):
16-
# importing inside the gpu-only test because these packages can't be
17-
# imported on the CPU image since they are not present there.
18-
from jax import grad, jit
19-
15+
def test_grad(self):
2016
grad_tanh = grad(self.tanh)
2117
ag = grad_tanh(1.0)
2218
self.assertEqual(0.4199743, ag)

0 commit comments

Comments
 (0)