
Commit 94c1232

Merge remote-tracking branch 'origin/main' into kv-xfer-updates

2 parents: 2c6cc8f + d373ec7

69 files changed: +5,863 −352 lines

.buildkite/test-pipeline.yaml

Lines changed: 28 additions & 1 deletion
@@ -41,6 +41,16 @@ steps:
   # TODO: add `--strict` once warnings in docstrings are fixed
   - mkdocs build

+- label: Pytorch Nightly Dependency Override Check # 2min
+  # if this test fails, it means the nightly torch version is not compatible with some
+  # of the dependencies. Please check the error message and add the package to whitelist
+  # in /vllm/tools/generate_nightly_torch_test.py
+  soft_fail: true
+  source_file_dependencies:
+  - requirements/nightly_torch_test.txt
+  commands:
+  - bash standalone_tests/pytorch_nightly_dependency.sh
+
 - label: Async Engine, Inputs, Utils, Worker Test # 24min
   mirror_hardwares: [amdexperimental]
   source_file_dependencies:
@@ -168,6 +178,23 @@ steps:
   - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
   - popd

+- label: EPLB Algorithm Test
+  working_dir: "/vllm-workspace/tests"
+  source_file_dependencies:
+  - vllm/distributed/eplb
+  - tests/distributed/test_eplb_algo.py
+  commands:
+  - pytest -v -s distributed/test_eplb_algo.py
+
+- label: EPLB Execution Test # 5min
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 4
+  source_file_dependencies:
+  - vllm/distributed/eplb
+  - tests/distributed/test_eplb_execute.py
+  commands:
+  - pytest -v -s distributed/test_eplb_execute.py
+
 - label: Metrics, Tracing Test # 10min
   mirror_hardwares: [amdexperimental, amdproduction]
   num_gpus: 2
@@ -750,7 +777,7 @@ steps:
   - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt

 - label: Weight Loading Multiple GPU Test - Large Models # optional
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental]
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   gpu: a100
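Two details in these new steps are worth noting: `soft_fail: true` is standard Buildkite semantics (the step may fail without failing the whole build, appropriate for a nightly-dependency canary), and `source_file_dependencies`, as used throughout this pipeline, schedules a step only when the listed paths are touched by a change.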

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,11 @@ repos:
     files: ^requirements/test\.(in|txt)$
 - repo: local
   hooks:
+  - id: format-torch-nightly-test
+    name: reformat nightly_torch_test.txt to be in sync with test.in
+    language: python
+    entry: python tools/generate_nightly_torch_test.py
+    files: ^requirements/test\.(in|txt)$
   - id: mypy-local
     name: Run mypy for local Python installation
     entry: tools/mypy.sh 0 "local"
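A usage note (standard pre-commit behavior, not shown in this commit): the new hook can be exercised on demand with `pre-commit run format-torch-nightly-test --all-files`; per its `files` pattern it otherwise fires only when `requirements/test.in` or `requirements/test.txt` changes.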

CMakeLists.txt

Lines changed: 20 additions & 2 deletions
@@ -513,6 +513,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       CUDA_ARCHS "${FP4_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
     list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
     message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
   else()
     message(STATUS "Not building NVFP4 as no compatible archs were found.")
@@ -547,8 +548,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
-        "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -566,6 +566,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

+  # moe_data.cu is used by all CUTLASS MoE kernels.
+  cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+  endif()
+
   #
   # Machete kernels

@@ -638,6 +648,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 # if CUDA endif
 endif()

+if (VLLM_GPU_LANG STREQUAL "HIP")
+  # Add QuickReduce kernels
+  list(APPEND VLLM_EXT_SRC
+    "csrc/custom_quickreduce.cu"
+  )
+# if ROCM endif
+endif()
+
 message(STATUS "Enabling C extension.")
 define_gpu_extension_target(
   _C
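Net effect of the two middle hunks: `moe_data.cu` moves out of the SM90-only `SCALED_MM_ARCHS` build into its own block compiled for whichever of 9.0a and 10.0a survive `cuda_archs_loose_intersection`, so both the SM90 and the new SM100 CUTLASS MoE kernels can link against it.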

csrc/custom_quickreduce.cu

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+#include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/all.h>
+
+#ifdef USE_ROCM
+
+  #include "quickreduce/quick_reduce.h"
+
+quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size,
+                                   std::optional<int64_t> qr_max_size) {
+  if (world_size > 8)
+    throw std::invalid_argument("world size > 8 is not supported");
+  if (world_size == 6)
+    throw std::invalid_argument("world size == 6 is not supported");
+  if (world_size % 2 != 0)
+    throw std::invalid_argument("Odd num gpus is not supported for now");
+  if (rank < 0 || rank >= world_size)
+    throw std::invalid_argument("invalid rank passed in");
+  quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
+  fptr->init(world_size, rank, qr_max_size);
+  return (quickreduce::fptr_t)fptr;
+}
+
+void qr_destroy(quickreduce::fptr_t _fa) {
+  if (_fa) {
+    auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+    fa->destroy();
+    delete fa;
+  }
+}
+
+torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
+  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+  hipIpcMemHandle_t handle = fa->get_handle();
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto data_handle =
+      torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
+  std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
+  return data_handle;
+}
+
+void qr_open_handles(quickreduce::fptr_t _fa,
+                     const std::vector<torch::Tensor>& handles) {
+  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+  std::vector<hipIpcMemHandle_t> ipc_handles;
+  ipc_handles.reserve(handles.size());
+  for (auto& handle : handles) {
+    // Ensure the tensor is on the same device as the current device.
+    hipIpcMemHandle_t ipc_handle;
+    std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
+    ipc_handles.push_back(ipc_handle);
+  }
+  fa->open_ipc_handles(ipc_handles);
+}
+
+void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
+                   torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
+  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+  auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
+
+  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+  TORCH_CHECK_EQ(inp.numel(), out.numel());
+  TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
+  if (out.scalar_type() == at::ScalarType::Half) {
+    fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()),
+                               reinterpret_cast<half*>(out.data_ptr()),
+                               out.numel(), quant_level, stream);
+  } else if (out.scalar_type() == at::ScalarType::BFloat16) {
+    if (cast_bf2half) {
+      fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()),
+                                reinterpret_cast<half*>(out.data_ptr()),
+                                out.numel(), quant_level, stream);
+    } else {
+      fa->allreduce<quickreduce::nv_bfloat16, false>(
+          reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
+          reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
+          out.numel(), quant_level, stream);
+    }
+  } else {
+    throw std::runtime_error(
+        "quick allreduce only supports float16 and bfloat16");
+  }
+}
+
+int64_t qr_max_size() {
+  // The default is 2GB (2,147,483,648 bytes)
+  return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
+}
+
+#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half)                   \
+  template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>,             \
+                                                cast_bf2half>;              \
+  template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>,             \
+                                                cast_bf2half>;              \
+  template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
+
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
+
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
+
+#endif  // USE_ROCM
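Note the symmetry between the guards and the instantiations: `init_custom_qr` admits only even world sizes up to 8 with 6 excluded (i.e. 2, 4, 8), exactly the sizes `INSTANTIATE_FOR_WORLDSIZE` expands to, and each dtype gets the full FP/Q4/Q6/Q8 codec set, presumably selected at run time through `quant_level` (the mapping lives in `quickreduce/quick_reduce.h`, not shown here).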

csrc/ops.h

Lines changed: 11 additions & 0 deletions
@@ -360,3 +360,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
     int64_t size);
 int64_t open_mem_handle(torch::Tensor& mem_handle);
 void free_shared_buffer(int64_t buffer);
+
+#ifdef USE_ROCM
+fptr_t init_custom_qr(int64_t rank, int64_t world_size,
+                      std::optional<int64_t> qr_max_size = std::nullopt);
+void qr_destroy(fptr_t _fa);
+torch::Tensor qr_get_handle(fptr_t _fa);
+void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
+void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
+                   int64_t quant_level, bool cast_bf2half = false);
+int64_t qr_max_size();
+#endif
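Taken together, these declarations suggest the following per-rank lifecycle. This is a minimal sketch, not code from this commit: `exchange_handles` stands in for whatever out-of-band mechanism (e.g. a CPU gather across ranks) shares the handle tensors, and the `quant_level` value is a placeholder whose codec mapping is defined in `quickreduce/quick_reduce.h`.

// Hypothetical sketch of the QuickReduce lifecycle on a single rank;
// assumes the declarations from csrc/ops.h above are in scope.
#include <torch/all.h>
#include <vector>

// Stand-in for an out-of-band handle exchange between ranks (hypothetical).
std::vector<torch::Tensor> exchange_handles(const torch::Tensor& local);

void run_quick_allreduce(int64_t rank, int64_t world_size,
                         torch::Tensor& inp, torch::Tensor& out) {
  // Allocate communicator state; qr_max_size defaults to 2 GiB (see qr_max_size()).
  fptr_t qr = init_custom_qr(rank, world_size);
  // Export this rank's buffer as a CPU uint8 tensor wrapping a hipIpcMemHandle_t...
  torch::Tensor local = qr_get_handle(qr);
  // ...and map every peer's buffer via HIP IPC.
  qr_open_handles(qr, exchange_handles(local));
  // fp16/bf16 only; quant_level picks one of the instantiated codecs.
  qr_all_reduce(qr, inp, out, /*quant_level=*/0);
  qr_destroy(qr);
}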

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh

Lines changed: 20 additions & 20 deletions
@@ -29,40 +29,40 @@ struct sm100_fp8_config_default {
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
 struct sm100_fp8_config_M256 {
-  // M in (128, 256]
+  // M in (64, 256]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
   using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_2, _2, _1>;
+  using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm100_fp8_config_M128 {
-  // M in (64, 128]
+struct sm100_fp8_config_M64 {
+  // M in (16, 64]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
-  using TileShape = Shape<_128, _128, _256>;
-  using ClusterShape = Shape<_2, _4, _1>;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;
   using Cutlass3xGemm =
       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm100_fp8_config_M64 {
-  // M in [1, 64]
+struct sm100_fp8_config_M16 {
+  // M in [1, 16]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
-  using TileShape = Shape<_64, _64, _256>;
-  using ClusterShape = Shape<_1, _8, _1>;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _4, _1>;
   using Cutlass3xGemm =
       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule>;
@@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
   using Cutlass3xGemmDefault =
       typename sm100_fp8_config_default<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM16 =
+      typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
      typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM128 =
-      typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM256 =
       typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;

   uint32_t const m = a.size(0);
   uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
+      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2

-  if (mp2 <= 64) {
-    // m in [1, 64]
-    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+  if (mp2 <= 16) {
+    // m in [1, 16]
+    return cutlass_gemm_caller<Cutlass3xGemmM16>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 128) {
-    // m in (64, 128]
-    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+  } else if (mp2 <= 64) {
+    // m in (16, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 256) {
-    // m in (128, 256]
+    // m in (64, 256]
     return cutlass_gemm_caller<Cutlass3xGemmM256>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else {
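The revised dispatch rounds `m` up to a power of two (clamped to at least 16) and routes it to the first bucket it fits. A standalone sketch of that arithmetic; `next_pow_2` here is a generic bit-twiddling stand-in for the helper vLLM already provides:

#include <algorithm>
#include <cstdint>

// Assumed helper; vLLM defines its own next_pow_2.
static uint32_t next_pow_2(uint32_t v) {
  v--;
  v |= v >> 1; v |= v >> 2; v |= v >> 4; v |= v >> 8; v |= v >> 16;
  return ++v;
}

// Mirrors the bucket boundaries in the diff above.
const char* sm100_fp8_bucket(uint32_t m) {
  uint32_t mp2 = std::max(16u, next_pow_2(m));
  if (mp2 <= 16) return "M16";    // m in [1, 16]:   64x64x128 tile, 1x4x1 cluster
  if (mp2 <= 64) return "M64";    // m in (16, 64]:  64x64x128 tile, 1x1x1 cluster
  if (mp2 <= 256) return "M256";  // m in (64, 256]: 128x128x128 tile, 2x1x1 cluster
  return "default";               // m > 256
}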

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 5 additions & 4 deletions
@@ -241,7 +241,7 @@ void get_cutlass_moe_mm_data(
   // mm to run it for.
   int32_t version_num = get_sm_version_num();
 #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
-    (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
+    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
   get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                  problem_sizes2, input_permutation,
                                  output_permutation, num_experts, n, k,
@@ -252,7 +252,7 @@ void get_cutlass_moe_mm_data(
       false,
       "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
       "CUDA device capability: ",
-      version_num, ". Required capability: 90");
+      version_num, ". Required capability: 90 or 100");
 }

 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
@@ -265,7 +265,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
   // This function currently gets compiled only if we have a valid cutlass moe
   // mm to run it for.
   int32_t version_num = get_sm_version_num();
-#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
   get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
                                       problem_sizes2, expert_num_tokens,
                                       num_local_experts, padded_m, n, k);
@@ -275,7 +276,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
       false,
       "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
       "for CUDA device capability: ",
-      version_num, ". Required capability: 90 or 100");
 }

 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
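The first hunk fixes a copy-paste bug: the old guard mixed macro families (`ENABLE_SCALED_MM_*` vs `ENABLE_CUTLASS_MOE_*`) and mismatched SM100/SM90 between the defined-check and the value-check. A minimal illustration (not from the commit) of the `defined X && X` idiom the fixed guard relies on:

// The same macro must be both defined and expand to a nonzero value.
#define ENABLE_CUTLASS_MOE_SM100 1

#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
// compiled only when the flag is defined *and* truthy
#endif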
