@@ -56,6 +56,14 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo libboost-all-dev && \
56
56
echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
57
57
/tmp/clean-layer.sh
58
58
59
+ # When using pip in a conda environment, conda commands should be ran first and then
60
+ # the remaining pip commands: https://www.anaconda.com/using-pip-in-a-conda-environment/
61
+ # However, because this image is based on the CPU image, this isn't possible but better
62
+ # to put them at the top of this file to minize conflicts.
63
+ RUN conda remove --force -y pytorch torchvision torchaudio cpuonly && \
64
+ conda install -y pytorch torchvision torchaudio cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION -c pytorch && \
65
+ /tmp/clean-layer.sh
66
+
59
67
# Install LightGBM with GPU
60
68
RUN pip uninstall -y lightgbm && \
61
69
cd /usr/local/src && \
@@ -72,21 +80,22 @@ RUN pip uninstall -y lightgbm && \
72
80
/tmp/clean-layer.sh
73
81
74
82
# Install JAX
83
+ # b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang.
84
+ ENV JAX_VERSION=0.1.62
85
+ ENV JAXLIB_VERSION=0.1.41
75
86
ENV JAX_PYTHON_VERSION=cp37
76
87
ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
77
88
ENV JAX_PLATFORM=linux_x86_64
78
89
ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
79
90
80
- RUN pip install --upgrade $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-0.1.43 -$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
81
- pip install --upgrade jax
91
+ RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION -$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
92
+ pip install jax==$JAX_VERSION
82
93
83
94
# Reinstall packages with a separate version for GPU support.
84
95
COPY --from=tensorflow_whl /tmp/tensorflow_gpu/*.whl /tmp/tensorflow_gpu/
85
96
RUN pip uninstall -y tensorflow && \
86
97
pip install /tmp/tensorflow_gpu/tensorflow*.whl && \
87
98
rm -rf /tmp/tensorflow_gpu && \
88
- conda remove --force -y pytorch torchvision torchaudio cpuonly && \
89
- conda install -y pytorch torchvision torchaudio cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION -c pytorch && \
90
99
pip uninstall -y mxnet && \
91
100
# b/126259508 --no-deps prevents numpy from being downgraded.
92
101
pip install --no-deps mxnet-cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION && \
0 commit comments