Skip to content

Commit babb56e

Browse files
authored
Merge pull request #782 from Kaggle/pin-jaxlib
Pin JAX and jaxlib
2 parents 45d5ce6 + 0964410 commit babb56e

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

Dockerfile

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ ENV PROJ_LIB=/opt/conda/share/proj
3939
RUN conda install -c conda-forge matplotlib basemap cartopy python-igraph imagemagick pysal && \
4040
# b/142337634#comment22 pin required to avoid torchaudio downgrade.
4141
conda install -c pytorch pytorch torchvision "torchaudio>=0.4.0" cpuonly && \
42-
# pixman 0.38.0 has a known issue causing issues with packages such as pycairo & openslide.
43-
# See: https://gitlab.freedesktop.org/pixman/pixman/commit/8256c235d9b3854d039242356905eca854a890ba
44-
conda install -c conda-forge --no-deps pixman==0.34 && \
4542
/tmp/clean-layer.sh
4643

4744
# The anaconda base image includes outdated versions of these packages. Update them to include the latest version.

gpu.Dockerfile

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo libboost-all-dev && \
5656
echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
5757
/tmp/clean-layer.sh
5858

59+
# When using pip in a conda environment, conda commands should be ran first and then
60+
# the remaining pip commands: https://www.anaconda.com/using-pip-in-a-conda-environment/
61+
# However, because this image is based on the CPU image, this isn't possible but better
62+
# to put them at the top of this file to minize conflicts.
63+
RUN conda remove --force -y pytorch torchvision torchaudio cpuonly && \
64+
conda install -y pytorch torchvision torchaudio cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION -c pytorch && \
65+
/tmp/clean-layer.sh
66+
5967
# Install LightGBM with GPU
6068
RUN pip uninstall -y lightgbm && \
6169
cd /usr/local/src && \
@@ -72,21 +80,22 @@ RUN pip uninstall -y lightgbm && \
7280
/tmp/clean-layer.sh
7381

7482
# Install JAX
83+
# b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang.
84+
ENV JAX_VERSION=0.1.62
85+
ENV JAXLIB_VERSION=0.1.41
7586
ENV JAX_PYTHON_VERSION=cp37
7687
ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
7788
ENV JAX_PLATFORM=linux_x86_64
7889
ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
7990

80-
RUN pip install --upgrade $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-0.1.43-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
81-
pip install --upgrade jax
91+
RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
92+
pip install jax==$JAX_VERSION
8293

8394
# Reinstall packages with a separate version for GPU support.
8495
COPY --from=tensorflow_whl /tmp/tensorflow_gpu/*.whl /tmp/tensorflow_gpu/
8596
RUN pip uninstall -y tensorflow && \
8697
pip install /tmp/tensorflow_gpu/tensorflow*.whl && \
8798
rm -rf /tmp/tensorflow_gpu && \
88-
conda remove --force -y pytorch torchvision torchaudio cpuonly && \
89-
conda install -y pytorch torchvision torchaudio cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION -c pytorch && \
9099
pip uninstall -y mxnet && \
91100
# b/126259508 --no-deps prevents numpy from being downgraded.
92101
pip install --no-deps mxnet-cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION && \

0 commit comments

Comments
 (0)