@@ -374,23 +374,44 @@ ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashin
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
- RUN --mount=type=cache,target=/root/.cache/uv \
378
- . /etc/environment && \
379
- if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
380
- # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
381
- if [[ "$CUDA_VERSION" == 12.8* ]]; then \
382
- uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
383
- else \
384
- export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
385
- git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
386
- # Needed to build AOT kernels
387
- (cd flashinfer && \
388
- python3 -m flashinfer.aot && \
389
- uv pip install --system --no-build-isolation . \
390
- ) && \
391
- rm -rf flashinfer; \
392
- fi \
393
- fi
377
# Install FlashInfer. On x86_64 with CUDA 12.8 a prebuilt wheel exists (enough
# for CI); otherwise build AOT kernels from source for the arch list below.
# arm64 is skipped entirely (no wheel, and source build is handled elsewhere).
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
  # Heredoc lines are NOT &&-chained: fail fast on any error, including in pipes.
  set -eo pipefail
  . /etc/environment
  if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
    # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
    if [[ "$CUDA_VERSION" == 12.8* ]]; then
      uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
    else
      # Default arches (skipping 10.0a and 12.0 since these need CUDA 12.8).
      # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
      TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
      if [[ "${CUDA_VERSION}" == 11.* ]]; then
        # 9.0a requires CUDA 12.x toolchains; drop it for CUDA 11.
        TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
      fi
      echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"

      # Shallow clone keeps the build context small; use the ARGs declared
      # above rather than hard-coding the repo URL/ref here.
      git clone --depth 1 --recursive --shallow-submodules \
          --branch ${FLASHINFER_GIT_REF} \
          ${FLASHINFER_GIT_REPO} flashinfer

      pushd flashinfer
        # Needed to build AOT kernels. NOTE(review): the arch list is applied
        # to the AOT step as well as the install, since it compiles per-arch
        # kernels — confirm against FlashInfer's build docs.
        TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
            python3 -m flashinfer.aot
        TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
            uv pip install --system --no-build-isolation .
      popd
      # Clean up the source tree in the same layer so it never reaches the image.
      rm -rf flashinfer
    fi
  fi
BASH
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .