File tree Expand file tree Collapse file tree 3 files changed +7
-9
lines changed Expand file tree Collapse file tree 3 files changed +7
-9
lines changed Original file line number Diff line number Diff line change @@ -420,6 +420,8 @@ RUN pip install flashtext && \
420
420
# pycrypto is used by competitions team.
421
421
pip install pycrypto && \
422
422
pip install easyocr && \
423
+ # Keep JAX version in sync with GPU image.
424
+ pip install jax==0.2.12 jaxlib==0.1.64 && \
423
425
/tmp/clean-layer.sh
424
426
425
427
# Download base easyocr models.
Original file line number Diff line number Diff line change @@ -77,7 +77,7 @@ RUN pip uninstall -y lightgbm && \
77
77
echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
78
78
/tmp/clean-layer.sh
79
79
80
- # Install JAX
80
+ # Install JAX (Keep JAX version in sync with CPU image)
81
81
RUN pip install jax==0.2.12 jaxlib==0.1.64+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
82
82
/tmp/clean-layer.sh
83
83
Original file line number Diff line number Diff line change 1
1
import unittest
2
-
3
2
import time
4
3
4
+ import jax .numpy as np
5
+
5
6
from common import gpu_test
7
+ from jax import grad , jit
6
8
7
9
8
10
class TestJAX (unittest .TestCase ):
9
11
def tanh (self , x ):
10
- import jax .numpy as np
11
12
y = np .exp (- 2.0 * x )
12
13
return (1.0 - y ) / (1.0 + y )
13
14
14
- @gpu_test
15
- def test_JAX (self ):
16
- # importing inside the gpu-only test because these packages can't be
17
- # imported on the CPU image since they are not present there.
18
- from jax import grad , jit
19
-
15
+ def test_grad (self ):
20
16
grad_tanh = grad (self .tanh )
21
17
ag = grad_tanh (1.0 )
22
18
self .assertEqual (0.4199743 , ag )
You can’t perform that action at this time.
0 commit comments