You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# When tensorflow is compatible with being installed alongside JAX/Pytorch then we no longer need to include the wheel and can install it directly.
15
+
# RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl
16
+
RUN mkdir -p /lib/wheels && curl --output /lib/wheels/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl
17
+
RUN curl --output /lib/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.3.0/libtpu.so
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
33
-
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
34
-
/tmp/clean-layer.sh
36
+
RUN cp $PIP_LIBTPU $PYTORCH_LIBTPU
35
37
36
38
# Monkey-patch TF, JAX & PYTORCH to load the correct libtpu.so when they are imported:
37
-
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
38
-
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py && \
39
-
sed -i "1s/^/from jax._src.cloud_tpu_init import cloud_tpu_init\ncloud_tpu_init()\n/" /opt/conda/lib/python3.7/site-packages/tensorflow/__init__.py
39
+
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /usr/local/lib/python3.8/site-packages/torch_xla/__init__.py && \
40
+
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /usr/local/lib/python3.8/site-packages/jax/_src/cloud_tpu_init.py
41
+
42
+
# Packages needed by the Notebook editor:
43
+
RUN pip install papermill jupyterlab python-lsp-server[all] jupyterlab-lsp
44
+
45
+
# Additional useful packages should be added here:
46
+
RUN pip install pandas
40
47
41
48
# Set these env vars so that they don't produce errs calling the metadata server to load them:
42
49
ENV TPU_ACCELERATOR_TYPE=v3-8
43
-
ENV TPU_PROCESS_ADDRESSES=local
50
+
ENV TPU_PROCESS_ADDRESSES=local
51
+
52
+
# Metadata
53
+
ARG GIT_COMMIT=unknown
54
+
ARG BUILD_DATE=unknown
55
+
56
+
LABEL git-commit=$GIT_COMMIT
57
+
LABEL build-date=$BUILD_DATE
58
+
ENV GIT_COMMIT=${GIT_COMMIT}
59
+
ENV BUILD_DATE=${BUILD_DATE}
60
+
61
+
LABEL tensorflow-version=$TENSORFLOW_VERSION
62
+
LABEL kaggle-lang=python
63
+
64
+
# Correlate current release with the git hash inside the kernel editor by running `!cat /etc/git_commit`.
65
+
RUN echo "$GIT_COMMIT" > /etc/git_commit && echo "$BUILD_DATE" > /etc/build_date
0 commit comments