Skip to content

Commit 0682232

Browse files
committed
Pin jax
1 parent 45d5ce6 commit 0682232

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

gpu.Dockerfile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,16 @@ RUN pip uninstall -y lightgbm && \
7272
/tmp/clean-layer.sh
7373

7474
# Install JAX
75+
# b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang.
76+
ENV JAX_VERSION=0.1.62
77+
ENV JAXLIB_VERSION=0.1.41
7578
ENV JAX_PYTHON_VERSION=cp37
7679
ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
7780
ENV JAX_PLATFORM=linux_x86_64
7881
ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
7982

80-
RUN pip install --upgrade $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-0.1.43-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
81-
pip install --upgrade jax
83+
RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
84+
pip install jax==$JAX_VERSION
8285

8386
# Reinstall packages with a separate version for GPU support.
8487
COPY --from=tensorflow_whl /tmp/tensorflow_gpu/*.whl /tmp/tensorflow_gpu/

0 commit comments

Comments
 (0)