
Commit f9c069c

Modularize fused experts and integrate PPLX kernels (#15956)
1 parent 418d2f8 commit f9c069c


42 files changed: +3835 −665 lines

csrc/activation_kernels.cu

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);              \
   dim3 grid(num_tokens);                                            \
   dim3 block(std::min(d, 1024));                                    \
+  if (num_tokens == 0) {                                            \
+    return;                                                         \
+  }                                                                 \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();     \
   VLLM_DISPATCH_FLOATING_TYPES(                                     \
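
The added guard returns before any launch work when the input batch is empty, since configuring a CUDA launch with dim3 grid(0) would otherwise be an invalid configuration. A minimal host-side sketch of the same pattern (illustrative only; launch_activation and activation_kernel are hypothetical names, not vLLM code):

// Hypothetical wrapper showing the early-return-on-empty-input pattern.
void launch_activation(const float* input, float* output, int64_t num_tokens,
                       int d, cudaStream_t stream) {
  if (num_tokens == 0) {
    return;  // empty batch: nothing to compute, and grid(0) is not a valid launch
  }
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  activation_kernel<<<grid, block, 0, stream>>>(output, input, d);  // hypothetical kernel
}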

csrc/dispatch_utils.h

Lines changed: 14 additions & 0 deletions
@@ -65,5 +65,19 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                              \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
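
The new VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES macro extends the existing integral dispatch to the unsigned index dtypes (UInt16, UInt32, UInt64) that the PPLX top-k ids use. A hedged usage sketch, mirroring how the macro is applied in moe_align_sum_kernels.cu below; launch_index_kernel, my_index_kernel, and topk_ids are hypothetical names:

// Hypothetical host-side dispatch over both signed and unsigned index dtypes.
void launch_index_kernel(torch::Tensor topk_ids, dim3 grid, dim3 block,
                         cudaStream_t stream) {
  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "my_index_kernel", [&] {
        // Inside the lambda, scalar_t is bound to the concrete element type
        // (int8_t ... int64_t, or uint16_t/uint32_t/uint64_t).
        my_index_kernel<scalar_t><<<grid, block, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), topk_ids.numel());
      });
}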

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 4 additions & 4 deletions
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   }

   if (use_global_memory) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               cumsum_buffer.data_ptr<int32_t>());
         });
   } else if (use_i16) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           // set dynamic shared mem
           auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               topk_ids.numel());
         });
   } else {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           auto kernel =
               vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   TORCH_CHECK(num_experts == 256,
               "sgl_moe_align_block_size kernel only supports deepseek v3.");

-  VLLM_DISPATCH_INTEGRAL_TYPES(
+  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `cumsum` tensors
         auto options_int =

csrc/moe/topk_softmax_kernels.cu

Lines changed: 45 additions & 18 deletions
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
     }
 }

-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {

     using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
  2) This implementation assumes k is small, but will work for any k.
 */

-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
 };
 } // namespace detail

-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         token_expert_indices, num_tokens, topk, 0, num_experts, \
         stream);

+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
     const float* gating_output,
     float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
     int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
     const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-    vllm::moe::topkGatingSoftmaxKernelLauncher(
-        gating_output.data_ptr<float>(),
-        topk_weights.data_ptr<float>(),
-        topk_indices.data_ptr<int>(),
-        token_expert_indices.data_ptr<int>(),
-        softmax_workspace.data_ptr<float>(),
-        num_tokens,
-        num_experts,
-        topk,
-        stream);
+
+    if(topk_indices.scalar_type() == at::ScalarType::Int)
+    {
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<int>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
+    else
+    {
+        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<uint32_t>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
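
With topkGatingSoftmaxKernelLauncher templated on IndType and the host function branching on the index dtype, topk_softmax can now write either int32 or uint32 top-k indices. A hedged caller-side sketch; the argument order of topk_softmax and the surrounding variables (gating_output, num_tokens, topk) are assumptions for illustration:

// Hypothetical caller: the uint32 allocation for topk_indices is the point here.
auto opts = gating_output.options();
torch::Tensor topk_weights =
    torch::empty({num_tokens, topk}, opts.dtype(at::ScalarType::Float));
torch::Tensor topk_indices =
    torch::empty({num_tokens, topk}, opts.dtype(at::ScalarType::UInt32));
torch::Tensor token_expert_indices =
    torch::empty({num_tokens, topk}, opts.dtype(at::ScalarType::Int));
topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output);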

examples/offline_inference/data_parallel.py

Lines changed: 16 additions & 6 deletions
@@ -65,11 +65,17 @@ def parse_args():
                         type=int,
                         default=0,
                         help="Master node port")
+    parser.add_argument("--enforce-eager",
+                        action='store_true',
+                        help="Enforce eager mode execution.")
+    parser.add_argument("--trust-remote-code",
+                        action='store_true',
+                        help="Trust remote code.")
     return parser.parse_args()


 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank):
+         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                      max_tokens=[16, 20][global_dp_rank % 2])

     # Create an LLM.
-    llm = LLM(model=model,
-              tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True,
-              enable_expert_parallel=True)
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=GPUs_per_dp_rank,
+        enforce_eager=enforce_eager,
+        enable_expert_parallel=True,
+        trust_remote_code=trust_remote_code,
+    )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size))
+                             tp_size, args.enforce_eager,
+                             args.trust_remote_code))
         proc.start()
         procs.append(proc)
     exit_code = 0

tests/kernels/moe/test_batched_moe.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+import triton.language as tl
+
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    invoke_moe_batched_triton_kernel)
+
+
+@dataclass
+class BatchedMMConfig:
+    dtype: torch.dtype
+    num_experts: int
+    max_tokens_per_expert: int
+    K: int
+    N: int
+
+
+@dataclass
+class BatchedMMTensors:
+    A: torch.Tensor  # [E, max_tokens, K]
+    B: torch.Tensor  # [E, K, N] - column major
+    C: torch.Tensor  # [E, max_tokens, N]
+    num_expert_tokens: torch.Tensor  # [E]
+
+    @staticmethod
+    def make_tensors(config: BatchedMMConfig):
+        A = torch.randn(
+            (config.num_experts, config.max_tokens_per_expert, config.K),
+            device="cuda",
+            dtype=config.dtype) / 10
+        B = torch.randn((config.num_experts, config.N, config.K),
+                        device="cuda",
+                        dtype=config.dtype)
+        C = torch.zeros(
+            (config.num_experts, config.max_tokens_per_expert, config.N),
+            device="cuda",
+            dtype=config.dtype)
+        num_expert_tokens = torch.randint(low=0,
+                                          high=config.max_tokens_per_expert,
+                                          size=(config.num_experts, ),
+                                          device="cuda",
+                                          dtype=torch.int32)
+        return BatchedMMTensors(A, B, C, num_expert_tokens)
+
+
+def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+             num_expert_tokens: torch.Tensor) -> torch.Tensor:
+
+    num_expert_tokens_cpu = num_expert_tokens.clone()
+    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
+    num_experts = num_expert_tokens.size(0)
+
+    for e in range(num_experts):
+        num_tokens = num_expert_tokens_cpu[e]
+        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
+
+    return C
+
+
+@pytest.mark.parametrize("num_experts", [16, 32])
+@pytest.mark.parametrize("max_tokens_per_expert",
+                         [32, 64, 128, 192, 224, 256, 512])
+@pytest.mark.parametrize("K", [128, 256, 1024])
+@pytest.mark.parametrize("N", [128, 256, 512, 1024])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
+                    N: int, dtype: torch.dtype):
+
+    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
+    tensors = BatchedMMTensors.make_tensors(config)
+
+    test_output = tensors.C
+    ref_output = test_output.clone()
+
+    compute_tl_dtype = {
+        torch.float16: tl.float16,
+        torch.bfloat16: tl.bfloat16,
+        torch.float32: tl.float32
+    }[test_output.dtype]
+    invoke_moe_batched_triton_kernel(
+        tensors.A,
+        tensors.B,
+        test_output,
+        tensors.num_expert_tokens,
+        compute_tl_dtype,
+        # Quantization data
+        None,
+        None,
+        None,
+        # Quantization schemes
+        False,
+        False,
+        False,
+        config={
+            "BLOCK_SIZE_M": 16,
+            "BLOCK_SIZE_N": 16,
+            "BLOCK_SIZE_K": 16
+        })
+
+    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
+                          tensors.num_expert_tokens)
+
+    rtol, atol = {
+        torch.float16: (6e-2, 6e-2),
+        torch.bfloat16: (6e-2, 6e-2),
+        torch.float32: (1e-2, 1e-2),
+    }[test_output.dtype]
+
+    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
