Skip to content

feat(torch): Update to PyTorch 2.7 & CUDA 12.8.1 #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 13, 2025
Merged
10 changes: 5 additions & 5 deletions .github/configurations/torch-base.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cuda: [ 12.8.0, 12.6.3, 12.4.1 ]
cuda: [ 12.8.1, 12.6.3, 12.4.1 ]
os: [ ubuntu22.04 ]
abi: [ 1, 0 ]
abi: [ 1 ]
include:
- torch: 2.6.0
vision: 0.21.0
audio: 2.6.0
- torch: 2.7.0
vision: 0.22.0
audio: 2.7.0
14 changes: 7 additions & 7 deletions .github/configurations/torch-nccl.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
cuda: [ 12.8.0, 12.6.3, 12.4.1 ]
cuda: [ 12.8.1, 12.6.3, 12.4.1 ]
os: [ ubuntu22.04 ]
abi: [ 1, 0 ]
abi: [ 1 ]
include:
- torch: 2.6.0
vision: 0.21.0
audio: 2.6.0
nccl: 2.25.1-1
nccl-tests-hash: 57fa979
- torch: 2.7.0
vision: 0.22.0
audio: 2.7.0
nccl: 2.26.5-1
nccl-tests-hash: 4658d92
9 changes: 6 additions & 3 deletions torch-extras/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ RUN apt-get -qq update && apt-get -qq install -y \
{ wget -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc \
| gpg --dearmor -o /etc/apt/trusted.gpg.d/kitware.gpg; } && \
apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" && \
apt-get -qq update && apt-get -qq install -y cmake && apt-get clean
apt-get -qq update && \
apt-get -qq install -y 'cmake=3.31.6-*' 'cmake-data=3.31.6-*' && \
apt-get clean && \
python3 -m pip install --no-cache-dir 'cmake==3.31.6'

# Update compiler (GCC) and linker (LLD) versions
# gfortran-11 is just for compiler_wrapper.f95
Expand Down Expand Up @@ -105,10 +108,10 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \
COPY --chmod=755 effective_cpu_count.sh .
COPY --chmod=755 scale.sh .

ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=compute_90a"
ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
FLAGS="${FLAGS} -gencode=arch=compute_100a,code=sm_100a" ;; \
FLAGS="${FLAGS} -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a" ;; \
esac && \
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf
ARG BUILD_MAX_JOBS
Expand Down
18 changes: 11 additions & 7 deletions torch/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# syntax=docker/dockerfile:1.7
ARG BUILDER_BASE_IMAGE="nvidia/cuda:12.8.0-devel-ubuntu22.04"
ARG FINAL_BASE_IMAGE="nvidia/cuda:12.8.0-base-ubuntu22.04"
ARG BUILDER_BASE_IMAGE="nvidia/cuda:12.8.1-devel-ubuntu22.04"
ARG FINAL_BASE_IMAGE="nvidia/cuda:12.8.1-base-ubuntu22.04"

ARG BUILD_TORCH_VERSION="2.5.1"
ARG BUILD_TORCH_VISION_VERSION="0.20.0"
ARG BUILD_TORCH_AUDIO_VERSION="2.5.0"
ARG BUILD_TORCH_VERSION="2.7.0"
ARG BUILD_TORCH_VISION_VERSION="0.22.0"
ARG BUILD_TORCH_AUDIO_VERSION="2.7.0"
ARG BUILD_TRANSFORMERENGINE_VERSION="1.13"
ARG BUILD_FLASH_ATTN_VERSION="2.7.4.post1"
ARG BUILD_FLASH_ATTN_3_VERSION="2.7.2.post1"
Expand Down Expand Up @@ -174,7 +174,10 @@ RUN apt-get -qq update && apt-get -qq install -y \
{ wget -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc \
| gpg --dearmor -o /etc/apt/trusted.gpg.d/kitware.gpg; } && \
apt-add-repository -n "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" && \
apt-get -qq update && apt-get -qq install -y cmake && apt-get clean
apt-get -qq update && \
apt-get -qq install -y 'cmake=3.31.6-*' 'cmake-data=3.31.6-*' && \
apt-get clean && \
python3 -m pip install --no-cache-dir 'cmake==3.31.6'

RUN mkdir /tmp/ccache-install && \
cd /tmp/ccache-install && \
Expand Down Expand Up @@ -340,7 +343,7 @@ ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
# Add sm_100a build if NV_CUDA_LIB_VERSION matches 12.[89].*
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
FLAGS="${FLAGS} -gencode=arch=compute_100a,code=sm_100a" ;; \
FLAGS="${FLAGS} -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a" ;; \
esac && \
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf

Expand Down Expand Up @@ -659,6 +662,7 @@ RUN export \
cuda-nvrtc-${CUDA_PACKAGE_VERSION} \
libcusparse-${CUDA_PACKAGE_VERSION} \
libcusolver-${CUDA_PACKAGE_VERSION} \
libcufile-${CUDA_PACKAGE_VERSION} \
cuda-cupti-${CUDA_PACKAGE_VERSION} \
libnvjpeg-${CUDA_PACKAGE_VERSION} \
libnvtoolsext1 && \
Expand Down
Loading