
Commit 1c89788

fix deepep ht tests

Signed-off-by: Bill Nell <bnell@redhat.com>

Parent: 56ebd31

2 files changed: 12 additions, 13 deletions

tests/kernels/moe/test_deepep_deepgemm_moe.py (11 additions, 12 deletions)

@@ -17,8 +17,6 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm

@@ -81,6 +79,7 @@ class TestConfig:
     k: int
     n: int
     num_experts: int
+    per_act_token_quant: bool
     block_size: list[int]
     # configs for testing low-latency kernels
     low_latency: bool

@@ -99,18 +98,15 @@ class TestTensors:
     def make(config: TestConfig, rank) -> "TestTensors":

         dtype = torch.bfloat16
-        topk, m, k, block_size = (config.topk, config.m, config.k,
-                                  config.block_size)
+        topk, m, k = (config.topk, config.m, config.k)

         fp8_info = torch.finfo(torch.float8_e4m3fn)
         fp8_max, fp8_min = fp8_info.max, fp8_info.min

         rank_tokens = torch.randn(
             (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
         rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
-
-        block_k = block_size[1]
-        _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
+        rank_token_scales = None

         topk_ids = torch.randint(
             low=0,

@@ -150,11 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
         q_dtype=q_dtype,
         block_shape=test_config.block_size)

-    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
-                                           world_size=pgi.world_size,
-                                           dp_size=dp_size,
-                                           block_shape=test_config.block_size,
-                                           per_act_token_quant=False)
+    fused_experts = BatchedDeepGemmExperts(
+        max_num_tokens=max_tokens_per_rank,
+        world_size=pgi.world_size,
+        dp_size=dp_size,
+        block_shape=test_config.block_size,
+        per_act_token_quant=test_config.per_act_token_quant)
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk

@@ -393,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
         k=k,
         n=n,
         num_experts=num_experts,
+        per_act_token_quant=False,
         block_size=block_size,
         low_latency=False,
         use_fp8_dispatch=None)

@@ -450,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(
         k=k,
         n=n,
         num_experts=num_experts,
+        per_act_token_quant=False,
         block_size=block_size,
         low_latency=True,
         use_fp8_dispatch=use_fp8_dispatch,
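
For context, a self-contained sketch of the TestConfig shape these hunks imply. The field list is reconstructed from the diff, the example values are made up, and the real class may declare additional fields. The new per_act_token_quant field replaces the False that was previously hard-coded at the BatchedDeepGemmExperts call site; both tests pass False, so behavior is unchanged while the knob becomes configurable. Note also that TestTensors.make no longer precomputes scales with per_token_group_quant_fp8: rank_token_scales is now None, presumably leaving scale computation to the quantization path itself.

from dataclasses import dataclass
from typing import Optional

@dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
    per_act_token_quant: bool  # new field added by this commit
    block_size: list[int]
    # configs for testing low-latency kernels
    low_latency: bool
    use_fp8_dispatch: Optional[bool]

# Hypothetical values; both tests in this commit pass
# per_act_token_quant=False, matching the old hard-coded default.
cfg = TestConfig(topk=2, m=32, k=128, n=128, num_experts=8,
                 per_act_token_quant=False, block_size=[128, 128],
                 low_latency=False, use_fp8_dispatch=None)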

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py (1 addition, 1 deletion)

@@ -147,7 +147,7 @@ def prepare(
             # quantization. Fallback to per_token_dynamic quant.
             per_token_quant = True
         else:
-            per_token_quant = ((quant_config.block_shape is not None) or
+            per_token_quant = ((quant_config.block_shape is None) or
                               (a1_scale is not None and a1_scale.numel() != 1)
                               or (a2_scale is not None
                                   and a2_scale.numel() != 1))
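
This one-word change is the actual bug fix: the old predicate enabled the per-token fallback exactly when a block shape was configured, i.e. it was inverted. After the fix, the prepare path falls back to per-token dynamic quantization only when no block-wise scheme is set, or when either activation scale is per-token (more than one element) rather than per-tensor. A minimal sketch of the corrected predicate follows; the standalone function name and signature are illustrative, since the real code reads these values from quant_config and the prepare() arguments.

from typing import Optional
import torch

def _needs_per_token_quant(block_shape: Optional[list[int]],
                           a1_scale: Optional[torch.Tensor],
                           a2_scale: Optional[torch.Tensor]) -> bool:
    # Fall back to per-token dynamic quantization when no block-wise
    # scheme is configured, or when either activation scale carries
    # more than one element (per-token rather than per-tensor).
    return ((block_shape is None)
            or (a1_scale is not None and a1_scale.numel() != 1)
            or (a2_scale is not None and a2_scale.numel() != 1))

# Block-wise fp8 with a per-tensor scale: stays on block quantization.
assert not _needs_per_token_quant([128, 128], torch.tensor([1.0]), None)
# No block shape configured: fall back to per-token dynamic quant.
assert _needs_per_token_quant(None, None, None)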
