|
| 1 | +#!/bin/bash |
| 2 | + |
| 3 | +set -e |
| 4 | + |
| 5 | +function get_JAXLIB_GPU_WHEEL { |
| 6 | + # c.f. https://github.com/google/jax#pip-installation |
| 7 | + local PYTHON_VERSION # alternatives: cp35, cp36, cp37, cp38 |
| 8 | + PYTHON_VERSION="cp"$(python3 --version | awk '{print $NF}' | awk '{split($0, rel, "."); print rel[1]rel[2]}') |
| 9 | + local CUDA_VERSION # alternatives: cuda90, cuda92, cuda100, cuda101 |
| 10 | + CUDA_VERSION="cuda"$(< /usr/local/cuda/version.txt awk '{print $NF}' | awk '{split($0, rel, "."); print rel[1]rel[2]}') |
| 11 | + local PLATFORM=linux_x86_64 |
| 12 | + local JAXLIB_VERSION=0.1.37 |
| 13 | + local BASE_URL="https://storage.googleapis.com/jax-releases" |
| 14 | + local JAXLIB_GPU_WHEEL="${BASE_URL}/${CUDA_VERSION}/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-${PLATFORM}.whl" |
| 15 | + echo "${JAXLIB_GPU_WHEEL}" |
| 16 | +} |
| 17 | + |
| 18 | +function install_backend() { |
| 19 | + # 1: the backend option name in setup.py |
| 20 | + local backend="${1}" |
| 21 | + if [[ "${backend}" == "tensorflow" ]]; then |
| 22 | + # shellcheck disable=SC2102 |
| 23 | + python3 -m pip install --no-cache-dir .[xmlio,tensorflow] |
| 24 | + elif [[ "${backend}" == "torch" ]]; then |
| 25 | + # shellcheck disable=SC2102 |
| 26 | + python3 -m pip install --no-cache-dir .[xmlio,torch] |
| 27 | + elif [[ "${backend}" == "jax" ]]; then |
| 28 | + python3 -m pip install --no-cache-dir .[xmlio] |
| 29 | + python3 -m pip install --no-cache-dir "$(get_JAXLIB_GPU_WHEEL)" |
| 30 | + python3 -m pip install --no-cache-dir jax |
| 31 | + fi |
| 32 | +} |
| 33 | + |
| 34 | +function main() { |
| 35 | + # 1: the backend option name in setup.py |
| 36 | + local BACKEND="${1}" |
| 37 | + install_backend "${BACKEND}" |
| 38 | +} |
| 39 | + |
| 40 | +main "$@" || exit 1 |
0 commit comments