@@ -387,30 +387,26 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
387
387
if [[ "$CUDA_VERSION" == 12.8* ]]; then
388
388
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
389
389
else
390
- export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
391
- git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
392
- # Needed to build AOT kernels
393
- (cd flashinfer && \
394
- python3 -m flashinfer.aot && \
395
- uv pip install --system --no-build-isolation . \
396
- )
397
- rm -rf flashinfer
398
-
399
- # Default arches (skipping 10.0a and 12.0 since these need 12.8)
390
+ # Exclude CUDA arches for older versions (11.x and 12.0-12.7)
400
391
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
401
- TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
402
392
if [[ "${CUDA_VERSION}" == 11.* ]]; then
403
- TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
393
+ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
394
+ elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
395
+ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
396
+ else
397
+ # CUDA 12.8+ supports 10.0a and 12.0
398
+ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
404
399
fi
405
- echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
400
+ echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
406
401
407
402
git clone --depth 1 --recursive --shallow-submodules \
408
- --branch v0.2.6.post1 \
409
- https://github.com/flashinfer-ai/flashinfer.git flashinfer
403
+ --branch ${FLASHINFER_GIT_REF} \
404
+ ${FLASHINFER_GIT_REPO} flashinfer
410
405
406
+ # Needed to build AOT kernels
411
407
pushd flashinfer
412
408
python3 -m flashinfer.aot
413
- TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
409
+ TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
414
410
uv pip install --system --no-build-isolation .
415
411
popd
416
412
0 commit comments