Skip to content

Commit 894baa4

Browse files
authored
Merge pull request #809 from Kaggle/clean-layer-jax
Call clean-layer script after JAX installation
2 parents 019d15e + bc42fea commit 894baa4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

gpu.Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ ENV JAX_PLATFORM=linux_x86_64
8989
ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
9090

9191
RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
92-
pip install jax==$JAX_VERSION
92+
pip install jax==$JAX_VERSION && \
93+
/tmp/clean-layer.sh
9394

9495
# Reinstall packages with a separate version for GPU support.
9596
COPY --from=tensorflow_whl /tmp/tensorflow_gpu/*.whl /tmp/tensorflow_gpu/
@@ -113,4 +114,4 @@ RUN pip install pycuda && \
113114
# ADD patches/tensorboard/notebook.py /opt/conda/lib/python3.7/site-packages/tensorboard/notebook.py
114115

115116
# Remove the CUDA stubs.
116-
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH_NO_STUBS"
117+
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH_NO_STUBS"

0 commit comments

Comments
 (0)