
Commit db773b0

fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 6f9b1e7 commit db773b0

8 files changed: +33 -24 lines


tests/kernels/moe/test_deepep_moe.py

Lines changed: 5 additions & 2 deletions

@@ -154,20 +154,23 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                                         deepep_ll_args = ll_args)

    if low_latency_mode:
+       # TODO(bnell): block_shape?
        fused_experts = BatchedTritonExperts(
            max_num_tokens=MAX_TOKENS_PER_RANK,
            world_size=pgi.world_size,
            dp_size=dp_size,
            use_fp8_w8a8=is_quantized,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
-           use_int4_w4a16=False)
+           use_int4_w4a16=False,
+           per_act_token_quant=False)
    else:
+       # TODO(bnell): block_shape?
        fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
                                      use_int8_w8a8=False,
                                      use_int8_w8a16=False,
                                      use_int4_w4a16=False,
-                                     per_channel_quant=False)
+                                     per_act_token_quant=False)

    mk = FusedMoEModularKernel(prepare_finalize=a2a,
                               fused_experts=fused_experts)
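
Note on the rename above: `per_channel_quant` becomes `per_act_token_quant` at these call sites. As a minimal sketch of what the flag selects (plain torch, not code from this commit; the helper and its clamping details are illustrative assumptions), per-act-token quantization keeps one fp8 scale per activation row, while the default keeps a single per-tensor scale:

    import torch

    def quantize_fp8(a: torch.Tensor, per_act_token_quant: bool):
        finfo = torch.finfo(torch.float8_e4m3fn)
        if per_act_token_quant:
            # One scale per token (row): shape (num_tokens, 1).
            scale = a.abs().amax(dim=-1, keepdim=True).float() / finfo.max
        else:
            # A single per-tensor scale: shape (1, 1).
            scale = a.abs().amax().float().view(1, 1) / finfo.max
        scale = scale.clamp(min=torch.finfo(torch.float32).tiny)
        aq = (a / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return aq, scale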

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ def pplx_cutlass_moe(
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
-       world_size=pgi.world_size,
+       world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim,  # because a.dtype.itemsize == 1

tests/kernels/moe/test_pplx_moe.py

Lines changed: 7 additions & 5 deletions

@@ -429,11 +429,13 @@ def pplx_moe(
        dp_size,
    )

-   experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
-                                  world_size=world_size,
-                                  dp_size=dp_size,
-                                  use_fp8_w8a8=qtype == torch.float8_e4m3fn,
-                                  block_shape=block_shape)
+   experts = BatchedTritonExperts(
+       max_num_tokens=max_num_tokens,
+       world_size=world_size,
+       dp_size=dp_size,
+       use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+       block_shape=block_shape,
+   )

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,

tests/kernels/moe/utils.py

Lines changed: 1 addition & 1 deletion

@@ -206,8 +206,8 @@ def batched_moe(
                             dp_size=1,
                             rank=0),
        BatchedTritonExperts(max_num_tokens=max_num_tokens,
-                            dp_size=1,
                             world_size=1,
+                            dp_size=1,
                             use_fp8_w8a8=qtype == torch.float8_e4m3fn,
                             per_act_token_quant=per_act_token,
                             block_shape=block_shape)

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 9 additions & 5 deletions

@@ -219,6 +219,7 @@ def __init__(
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )
+       assert max_experts_per_worker > 0
        self.max_experts_per_worker = max_experts_per_worker
        self.out_dtype = out_dtype
        self.per_out_ch_quant = per_out_ch_quant

@@ -249,7 +250,7 @@ def workspace_shapes(
        workspace1 = (M * topk, max(2 * N, K))
        workspace2 = (M * topk, N)
        output = (M * topk, K)
-       return (workspace1, workspace2, output, self.out_dtype)
+       return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype)

    def apply(
        self,

@@ -278,8 +279,9 @@ def apply(
                activation_callable, global_num_experts,
                expert_map, w1_scale, w2_scale, a1q_scale,
                a2_scale, workspace13, workspace2,
-               expert_num_tokens, self.out_dtype,
-               self.per_act_token, self.per_out_ch,
+               expert_num_tokens,
+               self.out_dtype if self.out_dtype is not None else hidden_states.dtype,
+               self.per_act_token_quant, self.per_out_ch_quant,
                self.use_batched_format)

@@ -343,10 +345,12 @@ def cutlass_moe_fp8(
    if out_dtype is None:
        out_dtype = a.dtype

+   num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
+
    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsFp8(
-           max_experts_per_worker=global_num_experts,
+           max_experts_per_worker=num_experts,
            out_dtype=out_dtype,
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_out_ch,

@@ -362,7 +366,7 @@ def cutlass_moe_fp8(
        topk_ids,
        False,
        activation,
-       global_num_experts if global_num_experts != -1 else w1_q.size(0),
+       num_experts,
        expert_map,
        w1_scale,
        w2_scale,
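
Note: the two `out_dtype` changes above give `None` a defined meaning ("follow the activation dtype") instead of letting `None` flow into the kernel. A minimal sketch of the fallback rule (the helper name is illustrative, not part of the commit):

    from typing import Optional
    import torch

    def resolve_out_dtype(out_dtype: Optional[torch.dtype],
                          hidden_states: torch.Tensor) -> torch.dtype:
        # Mirrors the inline conditionals above: an explicit out_dtype wins,
        # otherwise the output inherits the input activation dtype.
        return out_dtype if out_dtype is not None else hidden_states.dtype

    a = torch.randn(4, 8, dtype=torch.bfloat16)
    assert resolve_out_dtype(None, a) is torch.bfloat16
    assert resolve_out_dtype(torch.float16, a) is torch.float16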

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 6 additions & 6 deletions

@@ -799,12 +799,12 @@ def __init__(
        max_num_tokens: int,
        world_size: int,
        dp_size: int,
-       use_fp8_w8a8: bool,
-       use_int8_w8a8: bool,
-       use_int8_w8a16: bool,
-       use_int4_w4a16: bool,
-       per_act_token_quant: bool,
-       block_shape: Optional[list[int]],
+       use_fp8_w8a8: bool = False,
+       use_int8_w8a8: bool = False,
+       use_int8_w8a16: bool = False,
+       use_int4_w4a16: bool = False,
+       per_act_token_quant: bool = False,
+       block_shape: Optional[list[int]] = None,
    ):
        quant_dtype = get_config_quant_dtype(
            use_fp8_w8a8=use_fp8_w8a8,
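
With these defaults, callers only have to pass the batching and parallelism arguments; the quantization flags can be omitted on the unquantized path (which is what lets the test call sites above drop most keywords). A hedged usage sketch with illustrative values:

    # Unquantized config: the use_* flags default to False,
    # per_act_token_quant to False, and block_shape to None.
    experts = BatchedTritonExperts(
        max_num_tokens=64,
        world_size=1,
        dp_size=1,
    )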

vllm/model_executor/layers/fused_moe/prepare_finalize.py

Lines changed: 3 additions & 3 deletions

@@ -43,9 +43,9 @@ def prepare(
        a1.mul_(topk_weights.to(a1.dtype))

        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
-                                                  self.quant_dtype,
-                                                  self.per_act_token_quant,
-                                                  self.block_shape)
+                                                  quant_dtype,
+                                                  per_act_token_quant,
+                                                  block_shape)

        return a1q, a1q_scale, None, None, None
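
Note: `prepare` now reads the quantization parameters from local names rather than attributes stored on the prepare/finalize object; presumably `quant_dtype`, `per_act_token_quant`, and `block_shape` arrive per call (as arguments or unpacked from a quant config), so one object can serve calls with different quantization settings.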

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 1 addition & 1 deletion

@@ -592,7 +592,7 @@ def select_gemm_impl(self, prepare_finalize, moe):

        assert moe is not None

-       # method on prepare_finalize?
+       # method on prepare_finalize? sketchy getting world_size from prepare_finalize
        max_experts_per_worker = (
            (moe.num_experts + prepare_finalize.world_size - 1) //
            prepare_finalize.world_size)
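
The expression under the amended comment is an integer ceiling division: each worker is sized for ceil(num_experts / world_size) experts. A worked example with illustrative numbers:

    import math

    num_experts, world_size = 10, 4
    max_experts_per_worker = (num_experts + world_size - 1) // world_size
    assert max_experts_per_worker == 3 == math.ceil(num_experts / world_size)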
