File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -72,13 +72,16 @@ RUN pip uninstall -y lightgbm && \
72
72
/tmp/clean-layer.sh
73
73
74
74
# Install JAX
75
+ # b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang.
76
+ ENV JAX_VERSION=0.1.62
77
+ ENV JAXLIB_VERSION=0.1.41
75
78
ENV JAX_PYTHON_VERSION=cp37
76
79
ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
77
80
ENV JAX_PLATFORM=linux_x86_64
78
81
ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
79
82
80
- RUN pip install --upgrade $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-0.1.43 -$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
81
- pip install --upgrade jax
83
+ RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION -$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
84
+ pip install jax==$JAX_VERSION
82
85
83
86
# Reinstall packages with a separate version for GPU support.
84
87
COPY --from=tensorflow_whl /tmp/tensorflow_gpu/*.whl /tmp/tensorflow_gpu/
You can’t perform that action at this time.
0 commit comments