Commit 78fe775

[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 2f2fcb3 commit 78fe775

25 files changed: +1275, -661 lines
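
For background on the fp8 path this commit enables, the sketch below shows per-tensor fp8 (e4m3) activation quantization in plain PyTorch. The helper name quantize_fp8_per_tensor is hypothetical (not a vLLM API); it only illustrates the scale-and-cast step that the updated tests exercise through quant_dtype=torch.float8_e4m3fn. The per-act-token and block-quantized variants parametrized below compute the scale per row or per [128, 128] block instead of over the whole tensor.

import torch

def quantize_fp8_per_tensor(x: torch.Tensor):
    # Hypothetical helper, not vLLM code: map the tensor's max magnitude
    # onto the largest representable e4m3 value, then cast. Dequantize by
    # multiplying back by the scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax().clamp(min=1e-12) / finfo.max
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale

x = torch.randn(4, 128)
x_q, scale = quantize_fp8_per_tensor(x)
x_dq = x_q.to(torch.float32) * scale
assert torch.allclose(x, x_dq, atol=0.1, rtol=0.1)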

tests/kernels/moe/parallel_utils.py

Lines changed: 3 additions & 7 deletions

@@ -137,16 +137,14 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
                             low_latency_mode=low_latency_mode,
                             num_qps_per_rank=num_qps_per_rank)
     return DeepEPHTPrepareAndFinalize(buffer=buffer,
-                                      world_size=pgi.world_size,
-                                      rank=pgi.rank,
+                                      num_dispatchers=pgi.world_size,
                                       dp_size=dp_size,
                                       rank_expert_offset=pgi.rank *
                                       ht_args.num_local_experts)


 def make_deepep_ll_a2a(pg: ProcessGroup,
                        pgi: ProcessGroupInfo,
-                       dp_size: int,
                        deepep_ll_args: DeepEPLLArgs,
                        q_dtype: Optional[torch.dtype] = None,
                        block_shape: Optional[list[int]] = None):
@@ -166,8 +164,7 @@ def make_deepep_ll_a2a(pg: ProcessGroup,

     return DeepEPLLPrepareAndFinalize(
         buffer=buffer,
-        world_size=pgi.world_size,
-        dp_size=dp_size,
+        num_dispatchers=pgi.world_size,
         max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
         use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
     )
@@ -186,5 +183,4 @@ def make_deepep_a2a(pg: ProcessGroup,
                              block_shape)

     assert deepep_ll_args is not None
-    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
-                              block_shape)
+    return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)
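
The recurring API change in this commit replaces separate world_size/dp_size (and rank) arguments with a single num_dispatchers count. In this file the DeepEP prepare/finalize wrappers pass pgi.world_size directly, while the expert kernels in the files below derive the count as pgi.world_size // dp_size. A minimal sketch of that derivation, inferred from the test changes rather than taken from a vLLM definition:

def num_dispatchers(world_size: int, dp_size: int) -> int:
    # Mirrors the expression `pgi.world_size // dp_size` used in the
    # updated tests; assumes world_size is a multiple of dp_size.
    assert world_size % dp_size == 0
    return world_size // dp_size

assert num_dispatchers(world_size=8, dp_size=2) == 4
assert num_dispatchers(world_size=4, dp_size=1) == 4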

tests/kernels/moe/test_batched_moe.py

Lines changed: 54 additions & 30 deletions

@@ -10,7 +10,7 @@

 from tests.kernels.moe.utils import (batched_moe,
                                      make_quantized_test_activations,
-                                     make_test_weights, triton_moe)
+                                     make_test_weights, naive_batched_moe)
 from tests.kernels.quant_utils import native_batched_masked_quant_matmul
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
@@ -33,12 +33,10 @@
     (45, 512, 512),
     (45, 1024, 128),
     (45, 1024, 2048),
-    (64, 128, 128),
     (64, 512, 512),
     (64, 1024, 2048),
     (222, 128, 128),
     (222, 128, 2048),
-    (222, 512, 512),
     (222, 1024, 128),
     (222, 1024, 2048),
 ]
@@ -95,11 +93,12 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize("max_tokens_per_expert",
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
-@pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [None])
-@pytest.mark.parametrize("per_act_token_quant", [False])
+@pytest.mark.parametrize("N", [128, 256, 1024])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
                     block_shape: Optional[list[int]],
@@ -134,7 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
-        per_act_token_quant=per_act_token_quant)
+        per_act_token_quant=per_act_token_quant,
+    )

     B, B_q, B_scale, _, _, _ = make_test_weights(
         num_experts,
@@ -143,6 +143,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
+        per_act_token_quant=per_act_token_quant,
     )

     out_shape = (num_experts, max_tokens_per_expert, N)
@@ -177,6 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             "BLOCK_SIZE_N": 16,
             "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
         },
+        per_act_token_quant=per_act_token_quant,
         block_shape=block_shape,
     )

@@ -185,32 +187,31 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         B,
         ref_output,
         num_expert_tokens,
-        None,
-        None,
-        None,
     )

     q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                       num_expert_tokens,
                                                       A_scale, B_scale,
-                                                      block_shape)
+                                                      block_shape,
+                                                      per_act_token_quant)

     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]

-    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("per_act_token_quant", [False])
-@pytest.mark.parametrize("block_shape", [None])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("input_scales", [False])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -220,15 +221,19 @@ def test_fused_moe_batched_experts(
     dtype: torch.dtype,
     per_act_token_quant: bool,
     block_shape: Optional[list[int]],
+    input_scales: bool,
 ):
     current_platform.seed_everything(7)

     use_fp8_w8a8 = dtype == torch.float8_e4m3fn

+    if topk > e:
+        pytest.skip("topk > e")
+
     if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
         pytest.skip("Skip quantization test for non-quantized type")

-    if per_act_token_quant and block_shape is not None or topk > e:
+    if per_act_token_quant and block_shape is not None:
         pytest.skip("Skip illegal quantization test.")

     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
@@ -241,55 +246,74 @@ def test_fused_moe_batched_experts(
         act_dtype = dtype
         quant_dtype = None

-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
-                                                 n,
-                                                 k,
-                                                 block_shape=block_shape,
-                                                 in_dtype=act_dtype,
-                                                 quant_dtype=quant_dtype)
+    w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
+        e,
+        n,
+        k,
+        block_shape=block_shape,
+        in_dtype=act_dtype,
+        quant_dtype=quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+    )
+
+    if input_scales and quant_dtype is not None:
+        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None

     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        batched_output = batched_moe(
+
+        baseline_output = torch_experts(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         )
-        baseline_output = torch_experts(
+
+        batched_output = naive_batched_moe(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
-            block_shape=block_shape)
+            block_shape=block_shape,
+        )

-        triton_output = triton_moe(
+        triton_output = batched_moe(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         )

-        torch.testing.assert_close(triton_output,
+        torch.testing.assert_close(batched_output,
                                    baseline_output,
-                                   atol=2e-2,
+                                   atol=3e-2,
                                    rtol=2e-2)

         torch.testing.assert_close(triton_output,
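
The reference path in this test, native_batched_masked_quant_matmul, treats each expert's activation slab as padded: only the first num_expert_tokens[e] rows of A[e] hold real tokens. A rough, hypothetical reference for the unquantized case (not the vLLM helper, and assuming weights are laid out as (E, N, K)) looks like:

import torch

def masked_batched_mm_ref(A: torch.Tensor,      # (E, max_tokens, K) activations
                          B: torch.Tensor,      # (E, N, K) weights, assumed layout
                          num_expert_tokens: torch.Tensor) -> torch.Tensor:
    E, max_tokens, _ = A.shape
    N = B.shape[1]
    out = torch.zeros(E, max_tokens, N, dtype=A.dtype)
    for e in range(E):
        n = int(num_expert_tokens[e])
        # Rows past n are padding for expert e and stay zero here; the
        # real test only compares the valid region per expert.
        out[e, :n] = A[e, :n] @ B[e].t()
    return out

A = torch.randn(4, 8, 16)
B = torch.randn(4, 32, 16)
counts = torch.tensor([8, 3, 0, 5])
ref = masked_batched_mm_ref(A, B, counts)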

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 1 addition & 2 deletions

@@ -148,8 +148,7 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,

     fused_experts = BatchedDeepGemmExperts(
         max_num_tokens=max_tokens_per_rank,
-        world_size=pgi.world_size,
-        dp_size=dp_size,
+        num_dispatchers=pgi.world_size // dp_size,
         block_shape=test_config.block_size,
         per_act_token_quant=test_config.per_act_token_quant)
     mk = FusedMoEModularKernel(prepare_finalize=a2a,

tests/kernels/moe/test_deepep_moe.py

Lines changed: 3 additions & 2 deletions

@@ -154,12 +154,13 @@ def make_modular_kernel(
         deepep_ht_args = ht_args,
         deepep_ll_args = ll_args)

+    num_dispatchers = pgi.world_size // dp_size
+
     if low_latency_mode:
         assert not per_act_token_quant, "not supported in ll mode"
         fused_experts = BatchedTritonExperts(
             max_num_tokens=MAX_TOKENS_PER_RANK,
-            world_size=pgi.world_size,
-            dp_size=dp_size,
+            num_dispatchers=num_dispatchers,
             use_fp8_w8a8=is_quantized,
             use_int8_w8a8=False,
             use_int8_w8a16=False,

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 44 additions & 35 deletions

@@ -14,6 +14,7 @@
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
+from vllm.utils import cdiv

 from .parallel_utils import ProcessGroupInfo, parallel_launch

@@ -112,18 +113,21 @@ def pplx_cutlass_moe(
     w2_scale = w2_scale.to(device)
     a1_scale = a1_scale.to(device)

+    assert num_experts % world_size == 0
+    num_local_experts = cdiv(num_experts, world_size)
+    num_dispatchers = pgi.world_size // dp_size
+
     prepare_finalize = PplxPrepareAndFinalize(
         ata,
-        max_num_tokens,
-        pgi.world_size,
-        rank,
-        dp_size,
-    )
+        max_num_tokens=max_num_tokens,
+        num_local_experts=num_local_experts,
+        num_dispatchers=num_dispatchers)

-    experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
+    experts = CutlassExpertsFp8(num_local_experts,
                                 out_dtype,
                                 per_act_token,
                                 per_out_ch,
+                                num_dispatchers=num_dispatchers,
                                 use_batched_format=True)

     fused_cutlass_experts = FusedMoEModularKernel(
@@ -181,35 +185,40 @@ def _pplx_moe(
     per_out_ch: bool,
     use_internode: bool,
 ):
-    if use_internode:
-        uid = nvshmem_get_unique_id(
-        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-        torch.distributed.broadcast(uid, src=0)
-        nvshmem_init(uid, pgi.rank, pgi.world_size)
-    else:
-        group_ranks = list(range(pgi.world_size))
-        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
-        group_name = cpu_group.group_name
-
-    with set_current_vllm_config(vllm_config):
-        torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
-                                     topk_ids)
-        pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
-                                       w2_scale, topk_weights, topk_ids,
-                                       a1_scale, out_dtype, per_act_token,
-                                       per_out_ch, group_name)
-
-    torch_output = chunk_by_rank(torch_output, pgi.rank,
-                                 pgi.world_size).to(pplx_output.device)
-
-    # Uncomment if more debugging is needed
-    # print("PPLX OUT:", pplx_output)
-    # print("TORCH OUT:", torch_output)
-
-    torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
-
-    if use_internode:
-        nvshmem_finalize()
+    try:
+        if use_internode:
+            uid = nvshmem_get_unique_id(
+            ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+            torch.distributed.broadcast(uid, src=0)
+            nvshmem_init(uid, pgi.rank, pgi.world_size)
+        else:
+            group_ranks = list(range(pgi.world_size))
+            cpu_group = torch.distributed.new_group(group_ranks,
+                                                    backend="gloo")
+            group_name = cpu_group.group_name
+
+        with set_current_vllm_config(vllm_config):
+            torch_output = torch_experts(a_full, w1_full, w2_full,
+                                         topk_weights, topk_ids)
+            pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
+                                           w2_scale, topk_weights, topk_ids,
+                                           a1_scale, out_dtype, per_act_token,
+                                           per_out_ch, group_name)
+
+        torch_output = chunk_by_rank(torch_output, pgi.rank,
+                                     pgi.world_size).to(pplx_output.device)
+
+        # Uncomment if more debugging is needed
+        # print("PPLX OUT:", pplx_output)
+        # print("TORCH OUT:", torch_output)
+
+        torch.testing.assert_close(pplx_output,
+                                   torch_output,
+                                   atol=0.05,
+                                   rtol=0)
+    finally:
+        if use_internode:
+            nvshmem_finalize()


 @pytest.mark.parametrize("m", [2, 224])
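
Two small points about this hunk: the body is wrapped in try/finally so nvshmem_finalize() still runs if an assertion fails mid-test, and the explicit rounding-up expression is replaced with cdiv from vllm.utils (imported in the first hunk above). A stand-in definition of ceiling division with the same behavior; note the surrounding assert makes it equivalent to plain integer division here:

def cdiv(a: int, b: int) -> int:
    # Ceiling division; matches the replaced expression
    # (num_experts + world_size - 1) // world_size.
    return -(-a // b)

assert cdiv(16, 4) == 4   # evenly divisible: same as a // b
assert cdiv(17, 4) == 5   # rounds up when there is a remainder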
