
Commit 3d750ab

some quantization tweaks
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 337320f


2 files changed: +11 −7 lines changed


vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 3 additions & 4 deletions
@@ -147,10 +147,7 @@ def prepare(
             # quantization. Fallback to per_token_dynamic quant.
             per_token_quant = True
         else:
-            per_token_quant = ((quant_config.block_shape is None) or
-                               (a1_scale is not None and a1_scale.numel() != 1)
-                               or (a2_scale is not None
-                                   and a2_scale.numel() != 1))
+            per_token_quant = False
 
         if per_token_quant:
             a1q, a1q_scale = moe_kernel_quantize_input(
@@ -160,6 +157,8 @@ def prepare(
                 per_act_token_quant=True,
                 block_shape=quant_config.block_shape,
             )
+            if a1q_scale is not None and a1q_scale.numel() == 1:
+                a1q_scale = a1q_scale.view(1, 1)
         (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
          expert_topk_weights) = self._do_dispatch(
             tokens=a1q,
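
Net effect of this file's changes: the fallback path no longer infers per-token quantization from the scale shapes, and a scalar (per-tensor) scale coming back from moe_kernel_quantize_input is reshaped to 2-D so dispatch always sees a consistent scale layout. A minimal sketch of that normalization in plain PyTorch (the helper name is illustrative, not vLLM API):

import torch
from typing import Optional

# Illustrative sketch only: mirrors the a1q_scale tweak above by making
# a quantization scale always 2-D before dispatch. A scalar per-tensor
# scale becomes shape (1, 1); a per-token scale (num_tokens, 1) is kept.
def normalize_scale(scale: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if scale is not None and scale.numel() == 1:
        scale = scale.view(1, 1)
    return scale

assert normalize_scale(torch.tensor(0.5)).shape == (1, 1)
assert normalize_scale(torch.ones(8, 1)).shape == (8, 1)
assert normalize_scale(None) is None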

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 8 additions & 3 deletions
@@ -119,6 +119,10 @@ def prepare(
             block_shape=quant_config.block_shape)
 
         if a1q_scale is not None:
+            if a1q_scale.numel() == 1:
+                orig_a_scale_block_shape = 1
+            else:
+                orig_a_scale_block_shape = a1q_scale.shape[-1]
             a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
 
         # rem_experts need to be 0 for pplx to work properly.
@@ -143,8 +147,9 @@ def prepare(
         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
             float32_size = torch.float32.itemsize
-            block_size = (quant_config.block_shape[1] if quant_config.
-                          block_shape is not None else 1) * float32_size
+            block_size = (quant_config.block_shape[1]
+                          if quant_config.block_shape is not None else
+                          float32_size)
             expert_x_scale = torch.empty(
                 (
                     num_local_experts,
@@ -169,7 +174,7 @@ def prepare(
             bound_m=bound_m,
         )
         if expert_x_scale is not None:
-            expert_x_scale = expert_x_scale[:, :, 0:1]
+            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
 
         return expert_x, expert_x_scale, expert_num_tokens, None, None
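
Here the width of the original scale tensor is recorded before it is repeated for dispatch (1 for a per-tensor scale, otherwise the number of blocks in its last dimension), and the gathered expert_x_scale is sliced back to that width instead of being hard-coded to one column; the block_size expression also drops a stray multiplication by float32_size. A small self-contained sketch of the record-then-slice pattern (names are illustrative, not vLLM API):

import torch
from typing import Optional

# Illustrative sketch only: remember the original scale width before the
# scale is repeated/padded for dispatch, then trim the gathered result
# back to that width, as orig_a_scale_block_shape does above.
def scale_width(a1q_scale: Optional[torch.Tensor]) -> int:
    # A scalar per-tensor scale collapses to width 1; a block-quantized
    # scale keeps its last-dimension width (blocks along the hidden dim).
    if a1q_scale is None or a1q_scale.numel() == 1:
        return 1
    return a1q_scale.shape[-1]

scale = torch.ones(16, 4)          # e.g. 16 tokens, 4 scale blocks
width = scale_width(scale)         # -> 4
padded = scale.repeat(1, 2)        # repeat to a padded dispatch layout
trimmed = padded[..., :width]      # recover the original width
assert trimmed.shape == (16, 4)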
