Commit 680ecc5

fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 2401f70 commit 680ecc5

File tree: 4 files changed (+10, -6 lines)


tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 2 additions & 1 deletion
@@ -207,7 +207,8 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
     fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
                                            world_size=pgi.world_size,
                                            dp_size=dp_size,
-                                           block_shape=test_config.block_size)
+                                           block_shape=test_config.block_size,
+                                           per_act_token_quant=True)
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 3 additions & 2 deletions
@@ -21,7 +21,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
 
     def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
-                 block_shape: list[int]):
+                 block_shape: list[int],
+                 per_act_token_quant=False):
         """
         max_num_tokens: Maximum number of tokens from a DP Rank
         world_size: Number of EP ranks
@@ -31,7 +32,7 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
         super().__init__(
             FusedMoEQuantConfig(
                 quant_dtype=torch.float8_e4m3fn,
-                per_act_token_quant=False,
+                per_act_token_quant=per_act_token_quant,
                 block_shape=block_shape,
             ))
         assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
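For reference, a minimal usage sketch of the widened constructor, mirroring the test change above: the new per_act_token_quant keyword is simply forwarded into FusedMoEQuantConfig. The numeric values below are placeholders, not defaults taken from the codebase.

# Hypothetical usage sketch; values are placeholders.
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)

experts = BatchedDeepGemmExperts(
    max_num_tokens=256,        # maximum tokens from a DP rank (placeholder)
    world_size=2,              # number of EP ranks (placeholder)
    dp_size=1,                 # DP size (placeholder)
    block_shape=[128, 128],    # must equal BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE
    per_act_token_quant=True,  # new keyword, forwarded to FusedMoEQuantConfig
)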

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 3 additions & 1 deletion
@@ -94,12 +94,14 @@ def _do_quant(
                 ]) and quant_dtype is not None:
             # Quantization required despite none of the inputs suggesting
             # quantization. Fallback to per_token_dynamic quant.
+            #print(f"DYNAMIC")
             _per_act_token_quant = True
         else:
             _per_act_token_quant = ((block_shape is not None) or
                                     (a1_scale is not None and a1_scale.numel() != 1)
                                     or (a2_scale is not None
                                         and a2_scale.numel() != 1))
+            #print(f"{block_shape} {a1_scale} {a2_scale}")
 
         # assert per_act_token_quant == (
         #     (block_shape is not None)
@@ -108,7 +110,7 @@ def _do_quant(
 
 
         # TODO(bnell)
-        #assert per_act_token_quant == _per_act_token_quant
+        assert per_act_token_quant == _per_act_token_quant, f"{per_act_token_quant} == {_per_act_token_quant}"
 
         num_experts, max_tokens, hidden_dim = x.size()
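The second hunk promotes the commented-out consistency check to a hard assert: the per_act_token_quant flag supplied by the caller must now match what _do_quant infers from block_shape and the activation scales, and a mismatch reports both values. As a standalone restatement of that inference logic, here is a sketch; the helper name is made up, and the guard around the first branch is partly reconstructed since the hunk shows only its tail end.

from typing import Optional

import torch


def infer_per_act_token_quant(block_shape: Optional[list[int]],
                              a1_scale: Optional[torch.Tensor],
                              a2_scale: Optional[torch.Tensor],
                              quant_dtype: Optional[torch.dtype]) -> bool:
    # Illustrative helper, not part of vllm: mirrors the branching in _do_quant.
    if (block_shape is None and a1_scale is None and a2_scale is None
            and quant_dtype is not None):
        # Quantization is required even though none of the inputs suggest a
        # scheme: fall back to per-token dynamic quantization.
        return True
    return ((block_shape is not None)
            or (a1_scale is not None and a1_scale.numel() != 1)
            or (a2_scale is not None and a2_scale.numel() != 1))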

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 2 additions & 2 deletions
@@ -318,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
         expert_num_tokens: torch.Tensor,  # [E]
         compute_type: tl.dtype,
         # Quantization data
-        A_scale: Optional[torch.Tensor],
-        B_scale: Optional[torch.Tensor],
+        A_scale: torch.Tensor,  # Optional
+        B_scale: torch.Tensor,  # Optional
         B_zp: torch.Tensor,
         # Quantization schemes
         use_fp8_w8a8: bool,
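The last hunk only touches the annotations: the trailing comments suggest A_scale and B_scale may still be None at runtime, but Optional[...] is dropped from the static type in favor of a comment. A tiny, made-up illustration of that convention (the helper name is hypothetical):

import torch


def apply_scale(x: torch.Tensor,
                A_scale: torch.Tensor,  # Optional
                ) -> torch.Tensor:
    # Hypothetical helper: annotated as torch.Tensor, but callers may still
    # pass None, which the body guards against explicitly.
    if A_scale is None:
        return x
    return x * A_scale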
