Skip to content

Commit b7d9e94

Browse files
authored
[CI/Build] Fix FlashInfer double build in Dockerfile (#20651)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 7c12a76 commit b7d9e94

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

docker/Dockerfile

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -387,30 +387,26 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
387387
if [[ "$CUDA_VERSION" == 12.8* ]]; then
388388
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
389389
else
390-
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
391-
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
392-
# Needed to build AOT kernels
393-
(cd flashinfer && \
394-
python3 -m flashinfer.aot && \
395-
uv pip install --system --no-build-isolation . \
396-
)
397-
rm -rf flashinfer
398-
399-
# Default arches (skipping 10.0a and 12.0 since these need 12.8)
390+
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
400391
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
401-
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
402392
if [[ "${CUDA_VERSION}" == 11.* ]]; then
403-
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
393+
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
394+
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
395+
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
396+
else
397+
# CUDA 12.8+ supports 10.0a and 12.0
398+
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
404399
fi
405-
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
400+
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
406401

407402
git clone --depth 1 --recursive --shallow-submodules \
408-
--branch v0.2.6.post1 \
409-
https://github.com/flashinfer-ai/flashinfer.git flashinfer
403+
--branch ${FLASHINFER_GIT_REF} \
404+
${FLASHINFER_GIT_REPO} flashinfer
410405

406+
# Needed to build AOT kernels
411407
pushd flashinfer
412408
python3 -m flashinfer.aot
413-
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
409+
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
414410
uv pip install --system --no-build-isolation .
415411
popd
416412

0 commit comments

Comments
 (0)