Skip to content

Commit bdb84e2

Browse files
authored
[Bugfix] Fixes for FlashInfer's TORCH_CUDA_ARCH_LIST (#20136)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
1 parent 3dd3591 commit bdb84e2

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

docker/Dockerfile

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -374,23 +374,44 @@ ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashin
374374
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
375375
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
376376
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
377-
RUN --mount=type=cache,target=/root/.cache/uv \
378-
. /etc/environment && \
379-
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
380-
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
381-
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
382-
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
383-
else \
384-
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
385-
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
386-
# Needed to build AOT kernels
387-
(cd flashinfer && \
388-
python3 -m flashinfer.aot && \
389-
uv pip install --system --no-build-isolation . \
390-
) && \
391-
rm -rf flashinfer; \
392-
fi \
393-
fi
377+
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
378+
. /etc/environment
379+
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
380+
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
381+
if [[ "$CUDA_VERSION" == 12.8* ]]; then
382+
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
383+
else
384+
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
385+
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
386+
# Needed to build AOT kernels
387+
(cd flashinfer && \
388+
python3 -m flashinfer.aot && \
389+
uv pip install --system --no-build-isolation . \
390+
)
391+
rm -rf flashinfer
392+
393+
# Default arches (skipping 10.0a and 12.0 since these need 12.8)
394+
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
395+
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
396+
if [[ "${CUDA_VERSION}" == 11.* ]]; then
397+
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
398+
fi
399+
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
400+
401+
git clone --depth 1 --recursive --shallow-submodules \
402+
--branch v0.2.6.post1 \
403+
https://github.com/flashinfer-ai/flashinfer.git flashinfer
404+
405+
pushd flashinfer
406+
python3 -m flashinfer.aot
407+
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
408+
uv pip install --system --no-build-isolation .
409+
popd
410+
411+
rm -rf flashinfer
412+
fi \
413+
fi
414+
BASH
394415
COPY examples examples
395416
COPY benchmarks benchmarks
396417
COPY ./vllm/collect_env.py .

0 commit comments

Comments
 (0)