Commit 1466c79

Merge remote-tracking branch 'upstream/main' into upstream_merge_2025_05_15

2 parents c13eddf + 0ceaebf commit 1466c79

45 files changed: +1289 -528 lines changed

(Large commit: only a subset of the changed files is shown below.)

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 1 deletion
@@ -151,6 +151,8 @@ steps:
     # test with tp=2 and external_dp=2
     - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
     - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
+    # test with tp=2 and pp=2
+    - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
     # test with internal dp
     - python3 ../examples/offline_inference/data_parallel.py
     - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py

@@ -390,7 +392,7 @@ steps:

 - label: Tensorizer Test # 11min
   working_dir: "/vllm-workspace/tests"
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental, amdproduction]
   soft_fail: true
   source_file_dependencies:
   - vllm/model_executor/model_loader
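
For context, the new PP_SIZE=2 step exercises tensor parallelism together with pipeline parallelism through the same torchrun entry point as the existing steps. A minimal sketch of how such a test script could pick up the TP_SIZE/PP_SIZE environment variables follows; the actual contents of distributed/test_torchrun_example.py are not part of this diff, so the model name and defaults below are purely illustrative:

    import os

    from vllm import LLM, SamplingParams

    # Parallel sizes come from the environment so one script can serve the
    # different CI steps (tp=2/external_dp=2, tp=2/pp=2, ...).
    tp_size = int(os.getenv("TP_SIZE", "2"))
    pp_size = int(os.getenv("PP_SIZE", "1"))

    llm = LLM(
        model="facebook/opt-125m",  # illustrative placeholder model
        tensor_parallel_size=tp_size,
        pipeline_parallel_size=pp_size,
        # ranks are created and managed by torchrun, not by vLLM itself
        distributed_executor_backend="external_launcher",
        seed=0,
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))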

CMakeLists.txt

Lines changed: 6 additions & 3 deletions
@@ -302,7 +302,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # Only build Marlin kernels if we are building for at least some compatible archs.
   # Keep building Marlin for 9.0 as there are some group sizes and shapes that
   # are not supported by Machete yet.
-  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
+  # 9.0 for latest bf16 atomicAdd PTX
+  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
   if (MARLIN_ARCHS)

     #

@@ -446,8 +447,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   #
   # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
   # kernels for the remaining archs that are not already built for 3x.
+  # (Build 8.9 for FP8)
   cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
-    "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
+    "7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
   # subtract out the archs that are already built for 3x
   list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
   if (SCALED_MM_2X_ARCHS)

@@ -676,7 +678,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     CUDA_ARCHS "${CUDA_ARCHS}")

   list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
-  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
+  # 9.0 for latest bf16 atomicAdd PTX
+  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
   if (MARLIN_MOE_ARCHS)

     #
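
The arch lists can shrink because a +PTX entry embeds forward-compatible PTX alongside the SASS, so newer GPUs (SM 10.x/12.x) can JIT the 9.0 (or 8.9) PTX instead of needing their own explicit entries. Below is a rough Python analogue of the gencode flags such a spec expands to, assuming the sm_/compute_ pairing added to set_gencode_flags_for_srcs in cmake/utils.cmake further down; the real translation happens in CMake, not in Python:

    def gencode_flags(archs: str) -> list[str]:
        # "8.0;9.0+PTX" -> SASS for 8.0 and 9.0, plus 9.0 PTX that newer
        # GPUs (e.g. SM 10.x / 12.x) can JIT-compile at load time.
        flags = []
        for arch in archs.split(";"):
            has_ptx = arch.endswith("+PTX")
            num = arch.removesuffix("+PTX").replace(".", "")
            flags.append(f"-gencode arch=compute_{num},code=sm_{num}")
            if has_ptx:
                flags.append(f"-gencode arch=compute_{num},code=compute_{num}")
        return flags

    print(gencode_flags("8.0;9.0+PTX"))
    # ['-gencode arch=compute_80,code=sm_80',
    #  '-gencode arch=compute_90,code=sm_90',
    #  '-gencode arch=compute_90,code=compute_90']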

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 10 additions & 2 deletions
@@ -115,8 +115,16 @@ def bench_fp8(
     a_cont = a.contiguous()
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
-    block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
-    block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
+
+    def ceil_div(x: int, y: int) -> int:
+        return (x + y - 1) // y
+
+    block_scale_a = torch.rand(
+        (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
+    )
+    block_scale_b = torch.rand(
+        ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
+    )
     block_scale_a_M_major = block_scale_a.t().contiguous().t()
     block_scale_b_K_major = block_scale_b.t().contiguous().t()
     bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
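
The switch from k // 128 to ceil_div matters whenever a dimension is not a multiple of the 128-wide scale group: floor division would allocate no scale entry for the final partial block. A quick illustration with arbitrary sizes:

    def ceil_div(x: int, y: int) -> int:
        return (x + y - 1) // y

    k, n = 300, 512  # k deliberately not a multiple of 128
    print(k // 128, ceil_div(k, 128))  # 2 vs 3: floor division drops the partial block
    print(n // 128, ceil_div(n, 128))  # 4 vs 4: identical when evenly divisible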

cmake/utils.cmake

Lines changed: 69 additions & 20 deletions
@@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
     "${multiValueArgs}" ${ARGN} )

   foreach(_ARCH ${arg_CUDA_ARCHS})
-    string(REPLACE "." "" _ARCH "${_ARCH}")
-    set_gencode_flag_for_srcs(
-      SRCS ${arg_SRCS}
-      ARCH "compute_${_ARCH}"
-      CODE "sm_${_ARCH}")
+    # handle +PTX suffix: generate both sm and ptx codes if requested
+    string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
+    if(NOT _HAS_PTX EQUAL -1)
+      string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
+      string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "compute_${_STRIPPED_ARCH}")
+    else()
+      string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+    endif()
   endforeach()

   if (${arg_BUILD_PTX_FOR_ARCH})

@@ -251,7 +266,10 @@ endmacro()
 #
 # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
 # `<major>.<minor>[letter]` compute the "loose intersection" with the
-# `TGT_CUDA_ARCHS` list of gencodes.
+# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
+# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
+# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
+# architecture in `SRC_CUDA_ARCHS`.
 # The loose intersection is defined as:
 #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
 # where `<=` is the version comparison operator.

@@ -268,44 +286,63 @@ endmacro()
 #   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
 #   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
 #
+# Example With PTX:
+#   SRC_CUDA_ARCHS="8.0+PTX"
+#   TGT_CUDA_ARCHS="9.0"
+#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+#   OUT_CUDA_ARCHS="8.0+PTX"
+#
 function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
-  set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
+  set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
+  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
+
+  # handle +PTX suffix: separate base arch for matching, record PTX requests
+  set(_PTX_ARCHS)
+  foreach(_arch ${_SRC_CUDA_ARCHS})
+    if(_arch MATCHES "\\+PTX$")
+      string(REPLACE "+PTX" "" _base "${_arch}")
+      list(APPEND _PTX_ARCHS "${_base}")
+      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+      list(APPEND _SRC_CUDA_ARCHS "${_base}")
+    endif()
+  endforeach()
+  list(REMOVE_DUPLICATES _PTX_ARCHS)
+  list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)

   # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
   # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
   set(_CUDA_ARCHS)
-  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
-    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
-    if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
-      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
+  if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
+    if ("9.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
       set(_CUDA_ARCHS "9.0a")
     endif()
   endif()

-  if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
-    list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
+  if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
     if ("10.0" IN_LIST TGT_CUDA_ARCHS)
-      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
+      list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
       set(_CUDA_ARCHS "10.0a")
     endif()
   endif()

-  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+  list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
   # is less or equal to ARCH (but has the same major version since SASS binary
   # compatibility is only forward compatible within the same major version).
-  foreach(_ARCH ${TGT_CUDA_ARCHS_})
+  foreach(_ARCH ${_TGT_CUDA_ARCHS})
     set(_TMP_ARCH)
     # Extract the major version of the target arch
     string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
-    foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
+    foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
       # Extract the major version of the source arch
       string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
-      # Check major-version match AND version-less-or-equal
+      # Check version-less-or-equal, and allow PTX arches to match across majors
       if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
-        if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
+        if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
           set(_TMP_ARCH "${_SRC_ARCH}")
         endif()
       else()

@@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
   endforeach()

   list(REMOVE_DUPLICATES _CUDA_ARCHS)
+
+  # reapply +PTX suffix to architectures that requested PTX
+  set(_FINAL_ARCHS)
+  foreach(_arch ${_CUDA_ARCHS})
+    if(_arch IN_LIST _PTX_ARCHS)
+      list(APPEND _FINAL_ARCHS "${_arch}+PTX")
+    else()
+      list(APPEND _FINAL_ARCHS "${_arch}")
+    endif()
+  endforeach()
+  set(_CUDA_ARCHS ${_FINAL_ARCHS})
+
   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
 endfunction()
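
To make the new matching rule concrete, here is a small Python model of cuda_archs_loose_intersection (intuition only; it ignores the 9.0a/10.0a special cases handled above): for each target arch keep the highest source arch that is less than or equal to it, where plain entries must also share the major version but +PTX entries may match across majors because their PTX can be JIT-compiled on newer GPUs.

    def loose_intersection(src: list[str], tgt: list[str]) -> list[str]:
        ptx = {a[:-4] for a in src if a.endswith("+PTX")}
        bases = sorted({a[:-4] if a.endswith("+PTX") else a for a in src},
                       key=lambda v: tuple(map(int, v.split("."))))
        out = set()
        for t in tgt:
            tv, best = tuple(map(int, t.split("."))), None
            for s in bases:
                sv = tuple(map(int, s.split(".")))
                if sv <= tv and (s in ptx or sv[0] == tv[0]):
                    best = s  # keep the highest match; bases are sorted ascending
            if best is not None:
                out.add(best + "+PTX" if best in ptx else best)
        return sorted(out)

    print(loose_intersection(["8.0+PTX"], ["9.0"]))                 # ['8.0+PTX']
    print(loose_intersection(["8.0", "9.0+PTX"], ["8.6", "12.0"]))  # ['8.0', '9.0+PTX']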

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp

Lines changed: 39 additions & 21 deletions
@@ -1,5 +1,6 @@
 #include <torch/all.h>
 #include "cuda_utils.h"
+#include "cutlass_extensions/common.hpp"

 template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
 void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,

@@ -28,29 +29,46 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
       }
     }
   } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
+    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
+    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
+    int32_t version_num = get_sm_version_num();
+    if (version_num >= 100) {
+      TORCH_CHECK(
+          a.size(0) == a_scales.size(0) &&
+              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
+          "a_scale_group_shape must be [1, 128].");
+      TORCH_CHECK(
+          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
+              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
+          "b_scale_group_shape must be [128, 128].");
+    } else {
+      // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
+      // kernel, or introducing ceil_div to the load_init() of mainloop.
+      using GroupShape = std::array<int64_t, 2>;
+      auto make_group_shape = [](torch::Tensor const& x,
+                                 torch::Tensor const& s) -> GroupShape {
+        TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+        return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+                cuda_utils::ceil_div(x.size(1), s.size(1))};
+      };
+
+      GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+      GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+      // 1x128 per-token group scales for activations
+      // 128x128 blockwise scales for weights
+      TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                   b_scale_group_shape == GroupShape{128, 128} &&
+                   a.dtype() == torch::kFloat8_e4m3fn &&
+                   b.dtype() == torch::kFloat8_e4m3fn),
+                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                  "a_scale_group_shape must be [1, 128]. Got: [",
+                  a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                  "]\n"
+                  "b_scale_group_shape must be [128, 128]. Got: [",
+                  b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    }

-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
     TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
     blockwise_func(c, a, b, a_scales, b_scales);
   }
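
In effect, on SM100+ the helper now validates the scale shapes directly with ceil_div (so ragged tails in K or N are allowed), while the SM90 path keeps the old exact-group-shape computation until its kernel handles partial blocks. A rough Python rendering of the new SM100+ check, with function and variable names invented for illustration:

    import math

    def check_blockwise_scale_shapes(a_shape, b_shape, a_scales_shape, b_scales_shape):
        m, k = a_shape
        k_b, n = b_shape
        # activations: one scale per row per 128-wide group ([1, 128])
        assert a_scales_shape == (m, math.ceil(k / 128)), "a_scale_group_shape must be [1, 128]"
        # weights: one scale per 128x128 block
        assert b_scales_shape == (math.ceil(k_b / 128), math.ceil(n / 128)), \
            "b_scale_group_shape must be [128, 128]"

    # e.g. m=4, k=300, n=512 -> a_scales is (4, 3) and b_scales is (3, 4)
    check_blockwise_scale_shapes((4, 300), (300, 512), (4, 3), (3, 4))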

examples/offline_inference/torchrun_example.py

Lines changed: 14 additions & 9 deletions
@@ -8,6 +8,8 @@
 see `tests/distributed/test_torchrun_example.py` for the unit test.
 """

+import torch.distributed as dist
+
 from vllm import LLM, SamplingParams

 # Create prompts, the same across all ranks

@@ -27,23 +29,26 @@
 # all ranks have the same random seed, so that sampling can be
 # deterministic across ranks.
 llm = LLM(
-    model="facebook/opt-125m",
+    model="meta-llama/Llama-3.1-8B",
     tensor_parallel_size=2,
+    pipeline_parallel_size=2,
     distributed_executor_backend="external_launcher",
-    seed=0,
+    max_model_len=32768,
+    seed=1,
 )

 outputs = llm.generate(prompts, sampling_params)

 # all ranks will have the same outputs
-print("-" * 50)
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}\n"
-          f"Generated text: {generated_text!r}")
+if dist.get_rank() == 0:
     print("-" * 50)
-"""
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\n"
+              f"Generated text: {generated_text!r}\n")
+        print("-" * 50)
+"""
 Further tips:

 1. to communicate control messages across all ranks, use the cpu group,
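
Since the example now uses tensor_parallel_size=2 and pipeline_parallel_size=2 with the external_launcher backend, torchrun has to supply tp x pp processes, which is why the matching CI step launches four ranks. A tiny sanity-check sketch one could drop into the script (hypothetical addition, not part of this diff):

    import torch.distributed as dist

    # For this example the torchrun world size must equal tp x pp = 2 x 2 = 4,
    # matching the --nproc-per-node=4 used by the new CI step.
    expected_world_size = 2 * 2
    if dist.is_initialized():
        assert dist.get_world_size() == expected_world_size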
