Skip to content

feat(torch): Support compute capability 12.0 #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/configurations/torch-nccl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ include:
vision: 0.22.1
audio: 2.7.1
nccl: 2.27.3-1
nccl-tests-hash: d82e3c0
nccl-tests-hash: d554783
1 change: 1 addition & 0 deletions .github/workflows/torch-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
name: Build torch:base
needs: get-config
strategy:
fail-fast: false
matrix: ${{ fromJSON(needs.get-config.outputs.config) }}
uses: ./.github/workflows/torch.yml
secrets: inherit
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/torch-nccl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
name: Build torch:nccl
needs: get-config
strategy:
fail-fast: false
matrix: ${{ fromJSON(needs.get-config.outputs.config) }}
uses: ./.github/workflows/torch.yml
secrets: inherit
Expand Down
6 changes: 4 additions & 2 deletions torch-extras/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ARG DEEPSPEED_VERSION="0.14.4"
ARG APEX_COMMIT="a1df80457ba67d60cbdb0d3ddfb08a2702c821a8"
ARG DEEPSPEED_KERNELS_COMMIT="e77acc40b104696d4e73229b787d1ef29a9685b1"
ARG DEEPSPEED_KERNELS_CUDA_ARCH_LIST="80;86;89;90"
ARG XFORMERS_VERSION="0.0.28.post1"
ARG XFORMERS_VERSION="0.0.30"
ARG BUILD_MAX_JOBS=""

FROM alpine/git:2.36.3 as apex-downloader
Expand Down Expand Up @@ -111,7 +111,9 @@ COPY --chmod=755 scale.sh .
ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
FLAGS="${FLAGS} -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a" ;; \
FLAGS="${FLAGS}$( \
printf -- ' -gencode=arch=compute_%s,code=sm_%s' 120 120 100 100 100a 100a \
)" ;; \
esac && \
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf
ARG BUILD_MAX_JOBS
Expand Down
24 changes: 16 additions & 8 deletions torch/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ ARG BUILD_FLASH_ATTN_VERSION="2.7.4.post1"
ARG BUILD_FLASH_ATTN_3_VERSION="2.7.2.post1"
ARG BUILD_TRITON_VERSION=""
ARG BUILD_TRITON="1"
ARG BUILD_TORCH_CUDA_ARCH_LIST="7.0 8.0 8.9 9.0 10.0+PTX"
ARG BUILD_TRANSFORMERENGINE_CUDA_ARCH_LIST="70;80;89;90;100"
ARG BUILD_TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0 10.0 12.0+PTX"
ARG BUILD_TRANSFORMERENGINE_CUDA_ARCH_LIST="80;89;90;100;120"

ARG AOCL_BASE="/opt/aocl"
ARG AOCL_VER="4.2.0"
Expand Down Expand Up @@ -53,6 +53,8 @@ RUN ./clone.sh pytorch/pytorch pytorch "${BUILD_TORCH_VERSION}" && \
if [ "${BUILD_TORCH_VERSION}" = '2.5.1' ]; then \
wget 'https://github.com/pytorch/pytorch/commit/1cdaf1d85f5e4b3f8952fd0737a1afeb16995d13.patch' -qO- \
| git -C pytorch apply; \
elif [ "${BUILD_TORCH_VERSION}" = '2.7.1' ]; then \
git -C pytorch cherry-pick -n b74be524547f7f025066f19eda3b53a887c244ba; \
fi && \
rm -rf pytorch/.git

Expand Down Expand Up @@ -336,8 +338,8 @@ RUN --mount=type=bind,from=triton-downloader,source=/git/triton,target=triton/,r

ARG BUILD_TORCH_VERSION
ENV TORCH_VERSION=$BUILD_TORCH_VERSION
# Filter out the 10.0 arch on CUDA versions != 12.8 and != 12.9
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0/}||${TORCH_CUDA_ARCH_LIST}"
# Filter out the 10.0 & 12.0 arches on CUDA versions != 12.8 and != 12.9
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0 12.0/}||${TORCH_CUDA_ARCH_LIST}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
Expand All @@ -346,10 +348,12 @@ ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
RUN printf 'Arch: %s\nTORCH_CUDA_ARCH_LIST=%s\n' "$(uname -m)" "${TORCH_CUDA_ARCH_LIST}"

ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
# Add sm_100a build if NV_CUDA_LIB_VERSION matches 12.[89].*
# Add sm_100a & sm_120 builds if NV_CUDA_LIB_VERSION matches 12.[89].*
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
FLAGS="${FLAGS} -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a" ;; \
FLAGS="${FLAGS}$( \
printf -- ' -gencode=arch=compute_%s,code=sm_%s' 120 120 100 100 100a 100a \
)" ;; \
esac && \
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf

Expand Down Expand Up @@ -390,6 +394,10 @@ RUN --mount=type=bind,from=pytorch-downloader,source=/git/pytorch,target=pytorch
if [ -n "${BUILD_CXX11_ABI}" ]; then \
export _GLIBCXX_USE_CXX11_ABI="${BUILD_CXX11_ABI}"; \
fi && \
case "${NV_NVTX_VERSION}" in \
12.[0-8].*) ;; \
*) export USE_SYSTEM_NVTX=1 ;; \
esac && \
./storage-info.sh . && \
cd pytorch && \
../storage-info.sh . && \
Expand Down Expand Up @@ -652,8 +660,8 @@ ARG BUILD_TORCH_AUDIO_VERSION
ENV TORCH_VERSION=$BUILD_TORCH_VERSION
ENV TORCH_VISION_VERSION=$BUILD_TORCH_VISION_VERSION
ENV TORCH_AUDIO_VERSION=$BUILD_TORCH_AUDIO_VERSION
# Filter out the 10.0 arch on CUDA versions != 12.8 and != 12.9
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0/}||${TORCH_CUDA_ARCH_LIST}"
# Filter out the 10.0 & 12.0 arches on CUDA versions != 12.8 and != 12.9
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0 12.0/}||${TORCH_CUDA_ARCH_LIST}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
Expand Down
Loading