Commit c20591f

cleanup ctor args
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent addd937 commit c20591f

18 files changed: +240 −250 lines
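
The common thread across these diffs: quantization settings (quant_dtype, per_act_token_quant, block_shape) stop being constructor arguments of the prepare/finalize objects. They are instead owned by the experts classes, which forward them to the FusedMoEPermuteExpertsUnpermute base constructor, and they are passed into prepare() on each call. A minimal standalone sketch of that ownership model (illustrative only; this is not the real vLLM base class, whose definition is not part of this commit):

    from typing import Optional

    import torch

    class ExpertsQuantConfigBase:
        """Sketch: the quantization config lives on the experts object."""

        def __init__(self,
                     quant_dtype: Optional[torch.dtype] = None,
                     per_act_token_quant: bool = False,
                     block_shape: Optional[list[int]] = None):
            self.quant_dtype = quant_dtype
            self.per_act_token_quant = per_act_token_quant
            self.block_shape = block_shape

Subclasses such as BatchedDeepGemmExperts and BatchedTritonOrDeepGemmExperts now call super().__init__(quant_dtype=..., block_shape=..., per_act_token_quant=...) instead of stashing these values in ad hoc attributes, as the diffs below show.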

tests/kernels/moe/test_pplx_moe.py

Lines changed: 1 addition & 3 deletions
@@ -393,6 +393,7 @@ def pplx_moe(
     max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)

     hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
+        max_num_tokens,
         hidden_dim,
         a.dtype,
         qtype,
@@ -426,9 +427,6 @@ def pplx_moe(
         world_size,
         rank,
         dp_size,
-        quant_dtype=qtype,
-        per_act_token_quant=per_act_token_quant,
-        block_shape=block_shape,
     )

     experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,

tests/kernels/moe/utils.py

Lines changed: 2 additions & 5 deletions
@@ -204,14 +204,11 @@ def batched_moe(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  world_size=1,
                                  dp_size=1,
-                                 rank=0,
-                                 quant_dtype=qtype,
-                                 block_shape=block_shape,
-                                 per_act_token_quant=per_act_token),
+                                 rank=0),
        BatchedTritonExperts(max_num_tokens=max_num_tokens,
                             dp_size=1,
                             world_size=1,
-                            use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+                            use_fp8_w8a8=qtype==torch.float8_e4m3fn,
                             per_act_token_quant=per_act_token,
                             block_shape=block_shape)
    )
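
For context, a sketch of how the test helper assembles these two pieces after the change; the mk.FusedMoEModularKernel wrapper appears in the other files of this commit, but the import locations are assumptions, not shown in this diff:

    import torch

    import vllm.model_executor.layers.fused_moe.modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedPrepareAndFinalize, BatchedTritonExperts)  # paths assumed

    def make_batched_kernel(max_num_tokens, qtype, per_act_token, block_shape):
        # Prepare/finalize now carries only topology information.
        prepare_finalize = BatchedPrepareAndFinalize(max_num_tokens,
                                                     world_size=1,
                                                     dp_size=1,
                                                     rank=0)
        # The quantization flags stay on the experts object.
        experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                       dp_size=1,
                                       world_size=1,
                                       use_fp8_w8a8=qtype == torch.float8_e4m3fn,
                                       per_act_token_quant=per_act_token,
                                       block_shape=block_shape)
        return mk.FusedMoEModularKernel(prepare_finalize, experts)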

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,7 @@
 from typing import Any, Optional

 from vllm.model_executor.layers.fused_moe.layer import (
-    MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase,
+    FusedMoE, FusedMoEMethodBase,
     FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON

@@ -31,7 +31,6 @@ def get_config() -> Optional[dict[str, Any]]:
     "FusedMoeWeightScaleSupported",
     "override_config",
     "get_config",
-    "MOE_DP_CHUNK_SIZE",
 ]

 if HAS_TRITON:

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 14 additions & 7 deletions
@@ -19,22 +19,29 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     # The Deep Gemm kernels only support block size of 128
     DEEPGEMM_BLOCK_SHAPE = 128

-    def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
-                 block_shape: list[int]):
+    def __init__(
+        self,
+        max_num_tokens: int,
+        world_size: int,
+        dp_size: int,
+        block_shape: list[int]
+    ):
         """
         max_num_tokens: Maximum number of tokens from a DP Rank
         world_size: Number of EP ranks
         dp_size: Number of data-parallel ranks
         block_shape: Block quantization block shape
         """
-        super().__init__()
+
+        assert block_shape == [self.DEEPGEMM_BLOCK_SHAPE, self.DEEPGEMM_BLOCK_SHAPE]
+        super().__init__(
+            quant_dtype=torch.float8_e4m3fn,
+            block_shape=block_shape,
+            per_act_token_quant=False,
+        )
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
         self.dp_size = dp_size
-        self.block_shape = block_shape
-
-        assert (len(self.block_shape) == 2 and all(
-            [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))

     def supports_chunking(self) -> bool:
         return False
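
With the quantization config pushed into the base constructor, building the DeepGemm batched experts looks roughly like the sketch below; the [128, 128] block shape is mandatory (enforced by the assert), and quant_dtype/per_act_token_quant are fixed by the class rather than stored ad hoc:

    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts)

    BLOCK = BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE  # 128

    experts = BatchedDeepGemmExperts(
        max_num_tokens=256,   # example values, not from this commit
        world_size=8,
        dp_size=1,
        block_shape=[BLOCK, BLOCK],
    )
    # experts.quant_dtype is torch.float8_e4m3fn and
    # experts.per_act_token_quant is False, both set via super().__init__().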

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 49 additions & 39 deletions
@@ -12,57 +12,68 @@

 class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

-    def __init__(self,
-                 max_num_tokens: int,
-                 world_size: int,
-                 dp_size: int,
-                 use_fp8_w8a8: bool = False,
-                 use_int8_w8a8: bool = False,
-                 use_int8_w8a16: bool = False,
-                 use_int4_w4a16: bool = False,
-                 per_channel_quant: bool = False,
-                 block_shape: Optional[list[int]] = None,
-                 allow_deep_gemm: bool = False):
-        super().__init__()
+    def __init__(
+        self,
+        max_num_tokens: int,
+        world_size: int,
+        dp_size: int,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        block_shape: Optional[list[int]] = None,
+        per_act_token_quant: bool = False,
+        allow_deep_gemm: bool = False
+    ):
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            get_config_quant_dtype)
+
         assert not use_int8_w8a8, "NYI"
         assert not use_int8_w8a16, "NYI"
         assert not use_int4_w4a16, "NYI"

+        quant_dtype = get_config_quant_dtype(
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+        )
+        super().__init__(
+            quant_dtype=quant_dtype,
+            block_shape=block_shape,
+            per_act_token_quant=per_act_token_quant,
+        )
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
         self.dp_size = dp_size
-        self.use_fp8_w8a8 = use_fp8_w8a8
-        self.use_int8_w8a8 = use_int8_w8a8
-        self.use_int8_w8a16 = use_int8_w8a16
-        self.use_int4_w4a16 = use_int4_w4a16
-        self.per_channel_quant = per_channel_quant
-        self.block_shape = block_shape
-        self.allow_deep_gemm = allow_deep_gemm
-
-        # BatchedTritonKernel doesn't support block quantization
-        # at the moment.
+
         self.batched_triton_experts = BatchedTritonExperts(
             max_num_tokens=self.max_num_tokens,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            use_int8_w8a8=self.use_int8_w8a8,
-            use_int8_w8a16=self.use_int8_w8a16,
-            use_int4_w4a16=self.use_int4_w4a16,
-            per_channel_quant=self.per_channel_quant,
-            block_shape=self.block_shape,
             world_size=self.world_size,
-            dp_size=self.dp_size) if self.block_shape is None else None
+            dp_size=self.dp_size,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+            per_act_token_quant=self.per_act_token_quant,
+            block_shape=self.block_shape,
+        )
+
+        dg_block_shape = [BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE,
+                          BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE]
+
+        self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8
+                                and self.block_shape == dg_block_shape)

-        is_fp8_128_block_quantized = (self.use_fp8_w8a8
-                                      and self.block_shape is not None
-                                      and len(self.block_shape) == 2 and all(
-                                          [b == 128
-                                           for b in self.block_shape]))
         self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
             max_num_tokens=self.max_num_tokens,
             world_size=self.world_size,
             dp_size=self.dp_size,
-            block_shape=self.block_shape,  # type: ignore[arg-type]
-        ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None
+            block_shape=self.block_shape,
+        ) if self.allow_deep_gemm else None
+
+        assert (self.batched_triton_experts is not None or
+                (self.allow_deep_gemm and self.batched_deep_gemm_experts is not None))

         assert (self.batched_deep_gemm_experts is not None
                 or self.batched_triton_experts is not None)
@@ -86,11 +97,10 @@ def workspace_shapes(
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
+        if self.allow_deep_gemm:
             return self.batched_deep_gemm_experts.workspace_shapes(
                 a, aq, M, N, K, topk, num_experts)
         else:
-            assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
                 a, aq, M, N, K, topk, num_experts)

@@ -118,7 +128,7 @@ def apply(
                                         and self.batched_deep_gemm_experts
                                         is not None)
         experts = (self.batched_deep_gemm_experts
-                   if use_batched_deep_gemm_experts else
+                   if self.allow_deep_gemm else
                    self.batched_triton_experts)
         assert experts is not None
         experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
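
The DeepGemm-vs-Triton decision is now settled once in the constructor, so workspace_shapes() and apply() only test self.allow_deep_gemm. A condensed restatement of the new gating condition (not a copy of the class):

    DEEPGEMM_BLOCK_SHAPE = 128

    def resolve_allow_deep_gemm(allow_deep_gemm: bool, use_fp8_w8a8: bool,
                                block_shape) -> bool:
        dg_block_shape = [DEEPGEMM_BLOCK_SHAPE, DEEPGEMM_BLOCK_SHAPE]
        return (allow_deep_gemm and use_fp8_w8a8
                and block_shape == dg_block_shape)

    assert resolve_allow_deep_gemm(True, True, [128, 128])
    assert not resolve_allow_deep_gemm(True, True, None)         # no block quant
    assert not resolve_allow_deep_gemm(True, False, [128, 128])  # not fp8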

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 4 additions & 8 deletions
@@ -202,8 +202,7 @@ def run_cutlass_moe_fp8(


 # TODO (bnell): split class batched vs. non-batched?
-class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
-
+# maybe remove need for passing aq to workspace_shapes
 class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
@@ -212,12 +211,13 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        block_shape: Optional[list[int]] = None,
         use_batched_format: bool = False,
     ):
         super().__init__(
             quant_dtype=torch.float8_e4m3fn,
             per_act_token_quant=per_act_token_quant,
-            block_shape=None,
+            block_shape=block_shape,
         )
         self.max_experts_per_worker = max_experts_per_worker
         self.out_dtype = out_dtype
@@ -344,11 +344,7 @@ def cutlass_moe_fp8(
     out_dtype = a.dtype

     fn = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(
-            quant_dtype=torch.float8_e4m3fn,
-            per_act_token_quant=per_act_token,
-            block_shape=None,
-        ),
+        MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
             max_experts_per_worker=global_num_experts,
             out_dtype=out_dtype,
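
The same simplification applies at the call site: MoEPrepareAndFinalizeNoEP is now constructed with no arguments, and the fp8 details (including the new optional block_shape) live entirely on CutlassExpertsFp8. A hedged sketch of the resulting assembly; import paths are assumptions, not shown in this diff:

    import torch

    import vllm.model_executor.layers.fused_moe.modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP)  # path assumed

    def make_cutlass_fp8_kernel(num_experts: int) -> mk.FusedMoEModularKernel:
        return mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),      # no quant args here any more
            CutlassExpertsFp8(
                max_experts_per_worker=num_experts,
                out_dtype=torch.bfloat16,
                per_act_token_quant=False,
                per_out_ch_quant=False,
                block_shape=None,             # new optional ctor argument
            ),
        )

deep_gemm_moe_fp8 in the next file gets the identical treatment: its MoEPrepareAndFinalizeNoEP(...) call collapses to a bare MoEPrepareAndFinalizeNoEP().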

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 1 addition & 3 deletions
@@ -217,9 +217,7 @@ def deep_gemm_moe_fp8(
     - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
     """
     fn = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
-                                  per_act_token_quant=False,
-                                  block_shape=deep_gemm_block_shape()),
+        MoEPrepareAndFinalizeNoEP(),
         DeepGemmExperts(),
     )
     return fn(

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 4 additions & 5 deletions
@@ -20,17 +20,13 @@ def __init__(self,
                  world_size: int,
                  rank: int,
                  dp_size: int,
-                 rank_expert_offset: int,
-                 quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None):
+                 rank_expert_offset: int):
         super().__init__()
         self.buffer = buffer
         self.world_size = world_size
         self.rank = rank
         self.dp_size = dp_size
         self.rank_expert_offset = rank_expert_offset
-        self.quant_dtype = quant_dtype
-        self.block_shape = block_shape
         # The dispatch function returns a handle that the combine function
         # requires. We store the handle here so it is available to the
         # combine function.
@@ -135,6 +131,9 @@ def prepare(
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        quant_dtype: Optional[torch.dtype],
+        per_act_token_quant: bool,
+        block_shape: Optional[list[int]],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 4 additions & 4 deletions
@@ -41,10 +41,7 @@ def __init__(self,
                  buffer: deep_ep.Buffer,
                  world_size: int,
                  dp_size: int,
-                 max_tokens_per_rank: int,
-                 quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None,
-                 use_fp8_dispatch: bool = False):
+                 max_tokens_per_rank: int):
         super().__init__()

         self.buffer = buffer
@@ -123,6 +120,9 @@ def prepare(
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        quant_dtype: Optional[torch.dtype],
+        per_act_token_quant: bool,
+        block_shape: Optional[list[int]],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:

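Both DeepEP prepare/finalize classes now receive the quantization parameters on every prepare() call instead of freezing them at construction time. The sketch below restates that interface shape in isolation; it is not the real vLLM prepare/finalize base class, and the leading prepare() parameters (hidden states, router output, and so on) are elided because they are not visible in these hunks:

    from typing import Optional

    import torch

    class TopologyOnlyPrepareFinalize:
        """Sketch: ctor keeps topology only; quant config arrives per call."""

        def __init__(self, world_size: int, rank: int, dp_size: int):
            self.world_size = world_size
            self.rank = rank
            self.dp_size = dp_size

        def prepare(self, *leading_args,
                    quant_dtype: Optional[torch.dtype] = None,
                    per_act_token_quant: bool = False,
                    block_shape: Optional[list[int]] = None):
            # Quantization settings are threaded in per call, so the dispatch
            # path can quantize consistently with the experts' own config.
            raise NotImplementedError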