Skip to content

Commit a3082bd

Browse files
feat: Add GPU enabled Dockerfile (#748)
* Add NVIDIA GPU enabled Dockerfile for the machine learning framework backends * Add .dockerignore to reduce build context size
1 parent c125bd1 commit a3082bd

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

.dockerignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Ignore everything
2+
**
3+
4+
# Except the following
5+
!/.git/**
6+
!/docker/**
7+
!/src/**
8+
!/LICENSE
9+
!/pyproject.toml
10+
!/README.md
11+
!/setup.cfg
12+
!/setup.py

docker/gpu/Dockerfile

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 as base
2+
3+
FROM base as builder
4+
# hadolint ignore=DL3015
5+
RUN apt-get update -y && \
6+
apt-get install -y \
7+
git \
8+
python3 \
9+
python3-pip && \
10+
apt-get -y autoclean && \
11+
apt-get -y autoremove && \
12+
rm -rf /var/lib/apt-get/lists/*
13+
COPY . /code
14+
COPY ./docker/gpu/install_backend.sh /code/install_backend.sh
15+
WORKDIR /code
16+
ARG BACKEND=tensorflow
17+
RUN python3 -m pip install --upgrade --no-cache-dir pip setuptools wheel && \
18+
/bin/bash install_backend.sh ${BACKEND} && \
19+
python3 -m pip list
20+
21+
FROM base
22+
# Use C.UTF-8 locale to avoid issues with ASCII encoding
23+
ENV LC_ALL=C.UTF-8
24+
ENV LANG=C.UTF-8
25+
COPY --from=builder /lib/x86_64-linux-gnu /lib/x86_64-linux-gnu
26+
COPY --from=builder /usr/local /usr/local
27+
COPY --from=builder /usr/bin/python3 /usr/bin/python3
28+
COPY --from=builder /usr/bin/python3.6 /usr/bin/python3.6
29+
COPY --from=builder /usr/bin/pip3 /usr/bin/pip3
30+
COPY --from=builder /usr/lib/python3 /usr/lib/python3
31+
COPY --from=builder /usr/lib/python3.6 /usr/lib/python3.6
32+
COPY --from=builder /usr/lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu
33+
ENTRYPOINT ["/usr/local/bin/pyhf"]

docker/gpu/install_backend.sh

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/bin/bash
2+
3+
set -e
4+
5+
function get_JAXLIB_GPU_WHEEL {
6+
# c.f. https://github.com/google/jax#pip-installation
7+
local PYTHON_VERSION # alternatives: cp35, cp36, cp37, cp38
8+
PYTHON_VERSION="cp"$(python3 --version | awk '{print $NF}' | awk '{split($0, rel, "."); print rel[1]rel[2]}')
9+
local CUDA_VERSION # alternatives: cuda90, cuda92, cuda100, cuda101
10+
CUDA_VERSION="cuda"$(< /usr/local/cuda/version.txt awk '{print $NF}' | awk '{split($0, rel, "."); print rel[1]rel[2]}')
11+
local PLATFORM=linux_x86_64
12+
local JAXLIB_VERSION=0.1.37
13+
local BASE_URL="https://storage.googleapis.com/jax-releases"
14+
local JAXLIB_GPU_WHEEL="${BASE_URL}/${CUDA_VERSION}/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-${PLATFORM}.whl"
15+
echo "${JAXLIB_GPU_WHEEL}"
16+
}
17+
18+
function install_backend() {
19+
# 1: the backend option name in setup.py
20+
local backend="${1}"
21+
if [[ "${backend}" == "tensorflow" ]]; then
22+
# shellcheck disable=SC2102
23+
python3 -m pip install --no-cache-dir .[xmlio,tensorflow]
24+
elif [[ "${backend}" == "torch" ]]; then
25+
# shellcheck disable=SC2102
26+
python3 -m pip install --no-cache-dir .[xmlio,torch]
27+
elif [[ "${backend}" == "jax" ]]; then
28+
python3 -m pip install --no-cache-dir .[xmlio]
29+
python3 -m pip install --no-cache-dir "$(get_JAXLIB_GPU_WHEEL)"
30+
python3 -m pip install --no-cache-dir jax
31+
fi
32+
}
33+
34+
function main() {
35+
# 1: the backend option name in setup.py
36+
local BACKEND="${1}"
37+
install_backend "${BACKEND}"
38+
}
39+
40+
main "$@" || exit 1

0 commit comments

Comments
 (0)