[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. #18864

Merged
merged 77 commits on Jul 3, 2025
Changes from 72 commits
Commits
77 commits
dd40b1e
fp8 support
bnellnm May 21, 2025
a051108
wip
bnellnm May 21, 2025
43f9cfe
test
bnellnm May 21, 2025
39a2ab3
basic working test
bnellnm May 21, 2025
b5996ec
tests + fix
bnellnm May 22, 2025
356d4d7
stuff
bnellnm May 27, 2025
ad55ba1
cleanup quantization
bnellnm May 28, 2025
347f58e
merge
bnellnm May 28, 2025
035d324
fix merge
bnellnm May 28, 2025
ac46906
lint
bnellnm May 28, 2025
1a5d6b3
fixes
bnellnm May 29, 2025
29e314c
pplx + fp8 test
bnellnm May 29, 2025
cd5bc8f
fp8 + pplx tests + fixes
bnellnm May 29, 2025
037eb4a
re-enable cudagraph+torch.compile
bnellnm May 30, 2025
43dd36b
hacks
bnellnm May 30, 2025
a778f5a
clean up quantization parameters
bnellnm May 31, 2025
0ec77aa
lint
bnellnm May 31, 2025
c6a4451
progress on grouped quant for batched experts
bnellnm Jun 1, 2025
faa9b2f
wip
bnellnm Jun 2, 2025
985ce2e
triton + debug hacking
bnellnm Jun 2, 2025
4d114ee
batched mm tests with real scales + grouped quant
bnellnm Jun 3, 2025
8a019e2
hacking on tests
bnellnm Jun 4, 2025
dceee15
scale hacking
bnellnm Jun 4, 2025
d6eda9b
wip hacking
bnellnm Jun 6, 2025
f554eda
cleanup ctor args
bnellnm Jun 10, 2025
8e70e60
wip
bnellnm Jun 11, 2025
5f5e9a3
fixes
bnellnm Jun 11, 2025
9d30bcc
refactoring
bnellnm Jun 12, 2025
9fd5833
lint
bnellnm Jun 12, 2025
16a4d7f
lint
bnellnm Jun 12, 2025
c21e4df
lint
bnellnm Jun 12, 2025
82a0b1e
lint
bnellnm Jun 12, 2025
39c9b5e
fix merge. split up int8/fp8 moe tests
bnellnm Jun 13, 2025
c036763
wip
bnellnm Jun 18, 2025
861500e
cleanup
bnellnm Jun 20, 2025
ae39492
cleanup
bnellnm Jun 20, 2025
ae45963
fixes after merge
bnellnm Jun 21, 2025
b783ce6
torch_experts working
bnellnm Jun 21, 2025
3ea1454
fp8 baselines working
bnellnm Jun 22, 2025
9d8dd1d
mm baselines work
bnellnm Jun 22, 2025
5b376e5
prepare_finalize working
bnellnm Jun 23, 2025
0c06d4b
per token + grouped broken
bnellnm Jun 24, 2025
8d4e287
a scales working, b scales not working
bnellnm Jun 24, 2025
62404e3
blocked working
bnellnm Jun 24, 2025
b4dc46e
per_act_token working
bnellnm Jun 24, 2025
c3bddec
qwen works, rh-ds broken now, pplx_moe tests not all working
bnellnm Jun 25, 2025
185f090
both models work
bnellnm Jun 25, 2025
946a950
cleanup
bnellnm Jun 26, 2025
d8a2723
lint
bnellnm Jun 26, 2025
e40e9c0
fix test
bnellnm Jun 27, 2025
79e8d6b
fixes
bnellnm Jun 27, 2025
c02faca
fix pplx tests, fix indices type assert
bnellnm Jun 27, 2025
47eaa19
fixes
bnellnm Jun 27, 2025
8f3ee3a
fix lint
bnellnm Jun 27, 2025
1f15b73
fix per_act_token in pplx
bnellnm Jun 28, 2025
7d891cd
cleanups
bnellnm Jun 28, 2025
ca748ed
use proper experts for test_pplx_moe, naive experts work
bnellnm Jun 28, 2025
5eceb6d
fix test flag
bnellnm Jun 28, 2025
9b92fee
re-enable tests + loopify test_pplx_moe tests
bnellnm Jun 28, 2025
8894d0f
add optional tag to slow tests
bnellnm Jun 29, 2025
d135b41
fix lint
bnellnm Jun 29, 2025
35966a0
tweaks
bnellnm Jul 1, 2025
8485bde
fixes
bnellnm Jul 1, 2025
70fa1dd
fixup world_size/dp_size params
bnellnm Jul 1, 2025
9c56206
fix tests
bnellnm Jul 1, 2025
653942f
more test fixes
bnellnm Jul 1, 2025
d2dd405
fix merge
bnellnm Jul 1, 2025
ae91a5e
trim testcases
bnellnm Jul 2, 2025
76c697a
fix lint
bnellnm Jul 2, 2025
a5c8e85
ping
bnellnm Jul 2, 2025
285b2bc
fix num_dispatchers for TP+DP
bnellnm Jul 2, 2025
286d988
fix unit test
bnellnm Jul 2, 2025
14542e5
review comments
bnellnm Jul 2, 2025
2a96289
remove debug cruft
bnellnm Jul 2, 2025
562bb3e
review comments + scout fix
bnellnm Jul 2, 2025
a9b0730
remove bogus assert
bnellnm Jul 3, 2025
b37026d
scout fixes
bnellnm Jul 3, 2025
9 changes: 3 additions & 6 deletions tests/kernels/moe/parallel_utils.py
@@ -137,7 +137,7 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
num_dispatchers=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
@@ -146,7 +146,6 @@ def make_deepep_ht_a2a(pg: ProcessGroup,

def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
@@ -166,8 +165,7 @@ def make_deepep_ll_a2a(pg: ProcessGroup,

return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
num_dispatchers=pgi.world_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)
@@ -186,5 +184,4 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape)

assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)
return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)
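
A recurring change in this PR replaces the separate world_size/dp_size constructor arguments with a single num_dispatchers value. The DeepEP prepare/finalize helpers in this file pass pgi.world_size directly, while the expert kernels in the test files below derive it as pgi.world_size // dp_size. A small illustrative sketch with hypothetical values, not code from the PR:

    # Hypothetical values, only to illustrate the two derivations used in these tests.
    world_size = 4  # pgi.world_size: total ranks in the test process group
    dp_size = 2     # data-parallel group size passed to the test helpers

    # BatchedDeepGemmExperts / BatchedTritonExperts / CutlassExpertsFp8 (below):
    num_dispatchers = world_size // dp_size  # -> 2

    # DeepEPHTPrepareAndFinalize / DeepEPLLPrepareAndFinalize (above):
    num_dispatchers_prepare_finalize = world_size  # -> 4
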
84 changes: 54 additions & 30 deletions tests/kernels/moe/test_batched_moe.py
@@ -10,7 +10,7 @@

from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
make_test_weights, naive_batched_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
@@ -33,12 +33,10 @@
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 128, 128),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 512, 512),
(222, 1024, 128),
(222, 1024, 2048),
]
@@ -95,11 +93,12 @@ def make_tensors(config: BatchedMMConfig):
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("N", [128, 256, 1024])
@pytest.mark.parametrize(
"dtype",
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
@@ -134,7 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant)
per_act_token_quant=per_act_token_quant,
)

B, B_q, B_scale, _, _, _ = make_test_weights(
num_experts,
@@ -143,6 +143,7 @@
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)

out_shape = (num_experts, max_tokens_per_expert, N)
@@ -177,6 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
},
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)

@@ -185,32 +187,31 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
B,
ref_output,
num_expert_tokens,
None,
None,
None,
)

q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape)
block_shape,
per_act_token_quant)

rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]

torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
[Review thread on the input_scales parametrization above]

Contributor: Why is this only False?

Contributor Author (bnellnm): I've left it here for future testing.

Contributor: I see. Should there also be a condition in the test code to skip the test if input_scales == True and quant_dtype is None?

Contributor Author (bnellnm, Jul 3, 2025): That's one of the conditions that needs more testing. There are some int8/int4 quantization schemes that happen outside the triton kernels, so they need to pass in the quantized data + scales, but no quant_dtype, since they are already quantized.
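
A minimal sketch of the skip guard suggested in the thread above. The helper name is hypothetical and it is not part of this PR; as the author notes, int8/int4 schemes that quantize outside the kernels (already-quantized data plus scales, with quant_dtype=None) would eventually need to be exempted:

    import pytest

    def maybe_skip_unscaled_input(input_scales: bool, quant_dtype) -> None:
        # Input scales only make sense when the kernel quantizes activations
        # itself (i.e. quant_dtype is set). Schemes that quantize outside the
        # kernel pass already-quantized data + scales with quant_dtype=None;
        # that combination still needs more testing, so skip it for now.
        if input_scales and quant_dtype is None:
            pytest.skip("input_scales with quant_dtype=None is not exercised yet")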

def test_fused_moe_batched_experts(
m: int,
n: int,
@@ -220,15 +221,19 @@ def test_fused_moe_batched_experts(
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
input_scales: bool,
):
current_platform.seed_everything(7)

use_fp8_w8a8 = dtype == torch.float8_e4m3fn

if topk > e:
pytest.skip("topk > e")

if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")

if per_act_token_quant and block_shape is not None or topk > e:
if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization test.")

a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
@@ -241,55 +246,74 @@ def test_fused_moe_batched_experts(
act_dtype = dtype
quant_dtype = None

_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype)
w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
)

if input_scales and quant_dtype is not None:
a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
else:
a1_scale = None
a2_scale = None

with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(

baseline_output = torch_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
baseline_output = torch_experts(

batched_output = naive_batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
block_shape=block_shape,
)

triton_output = triton_moe(
triton_output = batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)

torch.testing.assert_close(triton_output,
torch.testing.assert_close(batched_output,
baseline_output,
atol=2e-2,
atol=3e-2,
rtol=2e-2)

torch.testing.assert_close(triton_output,
3 changes: 1 addition & 2 deletions tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -148,8 +148,7 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,

fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
num_dispatchers=pgi.world_size // dp_size,
block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
5 changes: 3 additions & 2 deletions tests/kernels/moe/test_deepep_moe.py
@@ -154,12 +154,13 @@ def make_modular_kernel(
deepep_ht_args = ht_args,
deepep_ll_args = ll_args)

num_dispatchers = pgi.world_size // dp_size

if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
world_size=pgi.world_size,
dp_size=dp_size,
num_dispatchers=num_dispatchers,
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
79 changes: 44 additions & 35 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import cdiv

from .parallel_utils import ProcessGroupInfo, parallel_launch

Expand Down Expand Up @@ -112,18 +113,21 @@ def pplx_cutlass_moe(
w2_scale = w2_scale.to(device)
a1_scale = a1_scale.to(device)

assert num_experts % world_size == 0
num_local_experts = cdiv(num_experts, world_size)
num_dispatchers = pgi.world_size // dp_size

prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
pgi.world_size,
rank,
dp_size,
)
max_num_tokens=max_num_tokens,
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)

experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
num_dispatchers=num_dispatchers,
use_batched_format=True)

fused_cutlass_experts = FusedMoEModularKernel(
@@ -181,35 +185,40 @@ def _pplx_moe(
per_out_ch: bool,
use_internode: bool,
):
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name

with set_current_vllm_config(vllm_config):
torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch, group_name)

torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)

# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)

torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)

if use_internode:
nvshmem_finalize()
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
group_name = cpu_group.group_name

with set_current_vllm_config(vllm_config):
torch_output = torch_experts(a_full, w1_full, w2_full,
topk_weights, topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch, group_name)

torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)

# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)

torch.testing.assert_close(pplx_output,
torch_output,
atol=0.05,
rtol=0)
finally:
if use_internode:
nvshmem_finalize()


@pytest.mark.parametrize("m", [2, 224])