
Commit 4f40568

fix weight config
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 3d750ab commit 4f40568

File tree

3 files changed (+12 -14 lines):

  vllm/model_executor/layers/fused_moe/config.py
  vllm/model_executor/layers/fused_moe/fused_batched_moe.py
  vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py


vllm/model_executor/layers/fused_moe/config.py

Lines changed: 1 addition & 1 deletion
@@ -361,7 +361,7 @@ def make(
        quant_dtype: Optional[torch.dtype] = None

        input_quant = get_quant_config_input_quant(quant_config)
-       weight_quant = get_quant_config_input_quant(quant_config)
+       weight_quant = get_quant_config_weight_quant(quant_config)

        if input_quant is not None:
            per_act_token_quant = (input_quant.strategy
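The one-line change above is a copy-paste fix: weight_quant was being read from the
input-activation section of the quantization config, so any scheme that quantizes
activations and weights differently saw the activation strategy where the weight
strategy was expected. A minimal sketch of the consequence, assuming a
compressed-tensors-style config whose input and weight sections carry independent
strategies (the strategy values below are illustrative, not taken from this commit):

    # Hypothetical scheme: per-token activations, per-channel weights.
    # Before the fix both lookups hit the input section, so weight_quant.strategy
    # reported the activation strategy and anything derived from the weight side
    # (e.g. per-channel weight-scale handling) was mis-configured.
    input_quant = get_quant_config_input_quant(quant_config)     # strategy == "token"
    weight_quant = get_quant_config_weight_quant(quant_config)   # strategy == "channel"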

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 1 addition & 1 deletion
@@ -800,7 +800,7 @@ def apply(
        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
-           quant_dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
+           quant_dtype=self.quant_dtype,
            per_act_token_quant=self.per_act_token_quant,
            block_shape=self.block_shape)

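This stops hard-coding the second activation's quantization dtype to fp8:
self.quant_dtype carries whatever dtype the experts' quantization config selected, so
this call now quantizes intermediate_cache2 consistently with the rest of the
configured scheme instead of keying off the fp8 flag alone. A sketch of the effective
difference (the non-fp8 dtype is an illustrative example, not taken from this commit):

    # Old expression: only fp8 configs produced a quant dtype; anything else got None
    # and the second GEMM's input was left unquantized.
    quant_dtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
    # New: the configured dtype is used directly.
    quant_dtype = self.quant_dtype          # e.g. torch.float8_e4m3fn or torch.int8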

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 10 additions & 12 deletions
@@ -34,7 +34,7 @@ def pplx_hidden_dim_scale_bytes(
        if per_act_token_quant:
            # per-token
            assert block_shape is None
-           hidden_scale_bytes = max_num_tokens * elem_size
+           hidden_scale_bytes = elem_size
        elif block_shape is not None:
            # per-group
            block_size = block_shape[1]
@@ -47,8 +47,10 @@
        hidden_dim_bytes = hidden_dim * in_dtype.itemsize
        hidden_scale_bytes = 0

-   return round_up(hidden_dim_bytes, align), round_up(hidden_scale_bytes,
-                                                      align)
+   return (
+       round_up(hidden_dim_bytes, align),
+       round_up(hidden_scale_bytes, align),
+   )


# The max_num_tokens, world_size and dp_size must be the same
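The first two hunks shrink the scale portion of the per-token byte budget used to size
the pplx all-to-all buffers (the return-statement change is purely cosmetic). With
per-act-token quantization each token carries exactly one scale value, so reserving
max_num_tokens * elem_size bytes for every token over-allocated the scale region by a
factor of max_num_tokens. A small worked example; the concrete values are illustrative
(elem_size is the float32 scale size, and the result is then rounded up to `align`):

    # Illustrative numbers only.
    elem_size = 4            # one float32 scale per token
    max_num_tokens = 256

    old_hidden_scale_bytes = max_num_tokens * elem_size   # 1024 bytes reserved per token
    new_hidden_scale_bytes = elem_size                     #    4 bytes per token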
@@ -111,7 +113,7 @@ def prepare(
            a1 = a1 * topk_weights.to(a1.dtype)

        repeat_cols = 4
-       repeat_rows = 1 if quant_config.per_act_token_quant else a1.shape[0]
+       repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1, (None if quant_config.per_act_token_quant else a1_scale),
            quant_dtype=quant_config.quant_dtype,
@@ -146,16 +148,12 @@ def prepare(

        expert_x_scale: Optional[torch.Tensor] = None
        if a1q.dtype.itemsize == 1:
-           float32_size = torch.float32.itemsize
            block_size = (quant_config.block_shape[1]
-                         if quant_config.block_shape is not None else
-                         float32_size)
+                         if quant_config.block_shape is not None else 1)
            expert_x_scale = torch.empty(
-               (
-                   num_local_experts,
-                   expert_x.size(1),
-                   (expert_x.size(2) + block_size - 1) // block_size,
-               ),
+               (num_local_experts, expert_x.size(1),
+                round_up(
+                    (expert_x.size(2) + block_size - 1) // block_size, 4)),
                dtype=torch.float32,
                device=device,
            )
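The last hunk changes how the receive-side scale tensor is sized: when there is no
block_shape the fallback divisor changes from the float32 element size (4) to 1, and
the resulting number of scale columns is padded to a multiple of 4, presumably to line
up with the 4-column scale replication (repeat_cols = 4) applied before dispatch. A
quick check of the new expression with illustrative sizes (round_up below is a
stand-in for the round-up-to-multiple helper the file already uses):

    def round_up(x: int, m: int) -> int:
        # Stand-in for the module's round_up helper: smallest multiple of m >= x.
        return (x + m - 1) // m * m

    hidden = 1408      # expert_x.size(2); illustrative

    block_size = 128   # block-quantized case, e.g. block_shape == [128, 128]
    round_up((hidden + block_size - 1) // block_size, 4)   # ceil(1408/128) = 11 -> 12 columns

    block_size = 1     # no block_shape: per-element fallback
    round_up((hidden + block_size - 1) // block_size, 4)   # 1408 columns (already a multiple of 4)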
