Commit 6f9b1e7

wip
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: c20591f

File tree

11 files changed: +114 -108 lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 8 additions & 8 deletions
@@ -56,7 +56,7 @@ class BatchedMMTensors:
 
     @staticmethod
     def make_tensors(config: BatchedMMConfig):
-        if config.in_dtype == torch.torch.float8_e4m3fn:
+        if config.in_dtype == torch.float8_e4m3fn:
             config_in_dtype = torch.bfloat16
         else:
             config_in_dtype = config.in_dtype
@@ -126,13 +126,13 @@ def ref_impl(
 def make_quantized_test_activations(E, m, k, dtype, block_shape, per_act_token):
     assert not per_act_token, "NYI"
 
-    a_type = torch.bfloat16 if dtype == torch.torch.float8_e4m3fn else dtype
+    a_type = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
 
     a = torch.randn((E, m, k), device="cuda", dtype=a_type) / 10
     a_q = a
     a_scale = None
 
-    if dtype == torch.torch.float8_e4m3fn:
+    if dtype == torch.float8_e4m3fn:
         a_q = torch.zeros_like(a, dtype=dtype)
         a_scale = [None] * E
         for e in range(E):
@@ -153,13 +153,13 @@ def make_quantized_test_activations(E, m, k, dtype, block_shape, per_act_token):
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
 @pytest.mark.parametrize(
     "dtype",
-    [torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype, block_shape: list[int]):
     current_platform.seed_everything(7)
 
-    use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
     per_act_token_quant = False
 
@@ -328,7 +328,7 @@ def _make_test_weights(
     n: int,
     k: int,
     block_size: list[int],
-    dtype=torch.torch.float8_e4m3fn,
+    dtype=torch.float8_e4m3fn,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
@@ -371,7 +371,7 @@ def _make_test_weights(
 
 
 def make_test_weights(e, n, k, block_shape, dtype):
-    use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
     w_dtype = torch.bfloat16 if use_fp8_w8a8 else dtype
 
     w1_16 = torch.randn((e, 2 * n, k), device="cuda", dtype=w_dtype) / 15
@@ -456,7 +456,7 @@ def test_fused_moe_batched_experts(
 ):
     current_platform.seed_everything(7)
 
-    use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
     quant_type = torch.float8_e4m3fn if use_fp8_w8a8 else None
 
     if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
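
Note: the old spelling torch.torch.float8_e4m3fn only resolves because torch re-exports itself as a module attribute, so torch.torch is the same module object; this commit simply drops the redundant prefix. A minimal check of that equivalence (assumes nothing beyond a stock PyTorch install):

    import torch

    # torch.torch is a self-referential alias of the torch module, so both
    # spellings name the same dtype object.
    assert torch.torch is torch
    assert torch.torch.float8_e4m3fn is torch.float8_e4m3fn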

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 0 additions & 2 deletions
@@ -118,8 +118,6 @@ def pplx_cutlass_moe(
         pgi.world_size,
         rank,
         dp_size,
-        quant_dtype=torch.float8_e4m3fn,
-        per_act_token=per_act_token,
     )
 
     experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -5,8 +5,7 @@
 from typing import Any, Optional
 
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON
 
 _config: Optional[dict[str, Any]] = None

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -36,8 +36,8 @@ def __init__(
         assert self.block_shape == [self.DEEPGEMM_BLOCK_SHAPE, self.DEEPGEMM_BLOCK_SHAPE]
         super().__init__(
             quant_dtype=torch.float8_e4m3fn,
-            block_shape=block_shape,
             per_act_token_quant=False,
+            block_shape=block_shape,
         )
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 2 additions & 2 deletions
@@ -348,8 +348,8 @@ def cutlass_moe_fp8(
         CutlassExpertsFp8(
             max_experts_per_worker=global_num_experts,
             out_dtype=out_dtype,
-            per_act_token=per_act_token,
-            per_out_ch=per_out_ch,
+            per_act_token_quant=per_act_token,
+            per_out_ch_quant=per_out_ch,
             use_batched_format=False,
         ),
     )

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 1 deletion
@@ -67,7 +67,11 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(self):
-        super().__init__(torch.float8_e4m3fn, False, deep_gemm_block_shape())
+        super().__init__(
+            quant_dtype=torch.float8_e4m3fn,
+            per_act_token_quant=False,
+            block_shape=deep_gemm_block_shape()
+        )
 
     def supports_chunking(self) -> bool:
         return True
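
Note: DeepGemmExperts here, like BatchedDeepGemmExperts above, now passes quant_dtype, per_act_token_quant, and block_shape to the base constructor by keyword, which protects against silently swapping look-alike positional arguments. A rough sketch of a base __init__ that this calling convention pairs with (a hypothetical stand-in, not the actual FusedMoEPermuteExpertsUnpermute definition):

    from typing import Optional

    import torch

    class QuantConfigBase:
        # Hypothetical stand-in: stores the three quantization settings that
        # the DeepGEMM and Triton expert classes now pass by keyword.
        def __init__(self,
                     quant_dtype: Optional[torch.dtype] = None,
                     per_act_token_quant: bool = False,
                     block_shape: Optional[list[int]] = None):
            self.quant_dtype = quant_dtype
            self.per_act_token_quant = per_act_token_quant
            self.block_shape = block_shape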

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 17 additions & 14 deletions
@@ -51,13 +51,6 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]:
             return None
         return deep_ep.Buffer.get_combine_config(self.dp_size)
 
-    def _do_quant(self, tokens: torch.Tensor,
-                  token_scales: Optional[torch.Tensor], per_act_token: bool):
-        tokens, token_scales = moe_kernel_quantize_input(
-            tokens, token_scales, self.quant_dtype, per_act_token,
-            self.block_shape)
-        return tokens, token_scales
-
     def _do_dispatch(self, tokens: torch.Tensor,
                      token_scales: Optional[torch.Tensor],
                      rank_topk_ids: torch.Tensor,
@@ -147,19 +140,25 @@ def prepare(
         # Check if there is a block_shape / or if we can infer the quantization
         # schemes from the scales.
         per_token_quant = None
-        if all([x is None for x in [self.block_shape, a1_scale, a2_scale]
-                ]) and self.quant_dtype is not None:
+        if all([x is None for x in [block_shape, a1_scale, a2_scale]
+                ]) and quant_dtype is not None:
             # Quantization required despite none of the inputs suggesting
             # quantization. Fallback to per_token_dynamic quant.
             per_token_quant = True
         else:
-            per_token_quant = ((self.block_shape is not None) or
+            per_token_quant = ((block_shape is not None) or
                                (a1_scale is not None and a1_scale.numel() != 1)
                                or (a2_scale is not None
                                    and a2_scale.numel() != 1))
 
         if per_token_quant:
-            a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
+            a1q, a1q_scale = moe_kernel_quantize_input(
+                a1,
+                a1_scale,
+                quant_dtype=quant_dtype,
+                per_act_token_quant=False,
+                block_shape=block_shape,
+            )
             (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1q,
@@ -180,9 +179,13 @@ def prepare(
             # quantize now
             expert_x_scale = None
             if expert_x.numel() != 0:
-                expert_x, expert_x_scale = self._do_quant(expert_x,
-                                                          a1_scale,
-                                                          per_act_token=False)
+                expert_x, expert_x_scale = moe_kernel_quantize_input(
+                    expert_x,
+                    a1_scale,
+                    quant_dtype=quant_dtype,
+                    per_act_token=False,
+                    block_shape=block_shape
+                )
 
         return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
                 expert_topk_weights)
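
Note: with the _do_quant helper removed, prepare() calls moe_kernel_quantize_input directly and reads the quantization scheme from its arguments instead of instance state. The branch that decides between per-token and non-per-token quantization can be summarized as a standalone predicate (hypothetical helper name, same logic as the hunk above):

    from typing import Optional

    import torch

    def infer_per_token_quant(block_shape: Optional[list[int]],
                              a1_scale: Optional[torch.Tensor],
                              a2_scale: Optional[torch.Tensor],
                              quant_dtype: Optional[torch.dtype]) -> bool:
        if block_shape is None and a1_scale is None and a2_scale is None:
            # Nothing in the inputs hints at a scheme, but quantization is
            # still requested: fall back to per-token dynamic quant.
            return quant_dtype is not None
        # Otherwise infer from the inputs: a block shape or non-scalar
        # activation scales imply per-token quantization.
        return (block_shape is not None
                or (a1_scale is not None and a1_scale.numel() != 1)
                or (a2_scale is not None and a2_scale.numel() != 1))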

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 43 additions & 29 deletions
@@ -37,20 +37,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     # specific hidden sizes.
     SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
 
-    def __init__(self,
-                 buffer: deep_ep.Buffer,
-                 world_size: int,
-                 dp_size: int,
-                 max_tokens_per_rank: int):
+    def __init__(
+        self,
+        buffer: deep_ep.Buffer,
+        max_tokens_per_rank: int,
+        world_size: int,
+        dp_size: int,
+        use_fp8_w8a8: bool
+    ):
         super().__init__()
 
         self.buffer = buffer
+        self.max_tokens_per_rank = max_tokens_per_rank
         self.world_size = world_size
         self.dp_size = dp_size
-        self.quant_dtype = quant_dtype
-        self.block_shape = block_shape
-        self.max_tokens_per_rank = max_tokens_per_rank
-        self.use_fp8_dispatch = use_fp8_dispatch
+        self.use_fp8_dispatch = use_fp8_w8a8
         # The dispatch function returns a handle that the combine function
         # requires. We store the handle here so it is available to the
         # combine function.
@@ -63,12 +64,17 @@ def topk_indices_dtype(self) -> Optional[torch.dtype]:
         return torch.int64
 
     def _do_quant(
-        self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
-        a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
-        a1_dtype: torch.dtype
+        self,
+        x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        a1_dtype: torch.dtype,
+        quant_dtype: Optional[torch.dtype],
+        per_act_token_quant: bool,
+        block_shape: Optional[list[int]],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
 
-        block_k = self.block_shape[1] if self.block_shape is not None else None
+        block_k = block_shape[1] if block_shape is not None else None
         if self.use_fp8_dispatch:
             if block_k == DEEPEP_QUANT_BLOCK_SIZE:
                 # DeepEP kernels did the quantization for us.
@@ -83,28 +89,33 @@ def _do_quant(
 
         # Check if there is a block_shape / or if we can infer the quantization
        # schemes from the scales.
-        per_token_quant = None
-        if all([v is None for v in [self.block_shape, a1_scale, a2_scale]
-                ]) and self.quant_dtype is not None:
-            # Quantization required despite none of the inputs suggesting
-            # quantization. Fallback to per_token_dynamic quant.
-            per_token_quant = True
-        else:
-            per_token_quant = ((self.block_shape is not None) or
-                               (a1_scale is not None and a1_scale.numel() != 1)
-                               or (a2_scale is not None
-                                   and a2_scale.numel() != 1))
+        # per_token_quant = None
+        # if all([v is None for v in [block_shape, a1_scale, a2_scale]
+        #         ]) and quant_dtype is not None:
+        #     # Quantization required despite none of the inputs suggesting
+        #     # quantization. Fallback to per_token_dynamic quant.
+        #     per_token_quant = True
+        # else:
+        #     per_token_quant = ((block_shape is not None) or
+        #                        (a1_scale is not None and a1_scale.numel() != 1)
+        #                        or (a2_scale is not None
+        #                            and a2_scale.numel() != 1))
+
+        assert per_act_token_quant == ((block_shape is not None) or
                                       (a1_scale is not None and a1_scale.numel() != 1)
                                       or (a2_scale is not None
                                           and a2_scale.numel() != 1))
 
         num_experts, max_tokens, hidden_dim = x.size()
 
         # TODO (varun): Optimization - Use a batched version of quant
         x = x.view((-1, hidden_dim))
-        x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
-                                                per_token_quant,
-                                                self.block_shape)
+        x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
+                                                per_act_token_quant,
+                                                block_shape)
         x = x.view((num_experts, -1, hidden_dim))
 
-        if per_token_quant:
+        if per_act_token_quant:
             assert x_scales is not None
             x_scales = x_scales.view(num_experts, max_tokens, -1)
 
@@ -159,7 +170,10 @@ def prepare(
                                        return_recv_hook=False)
 
         expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
-                                                  a1.dtype)
+                                                  a1.dtype,
+                                                  quant_dtype,
+                                                  per_act_token_quant,
+                                                  block_shape)
 
         return (expert_x, expert_x_scale, expert_num_tokens, None, None)
 
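
Note: in _do_quant the old inference block is kept only as a comment and replaced by an assertion, so the caller-supplied per_act_token_quant must already agree with what the scales imply. A small illustration of the asserted invariant with hypothetical inputs:

    import torch

    # Hypothetical inputs: no block quantization, one scale per token.
    block_shape = None
    a1_scale = torch.ones(16, 1)  # numel() != 1 -> per-token scales
    a2_scale = None

    implied = (block_shape is not None
               or (a1_scale is not None and a1_scale.numel() != 1)
               or (a2_scale is not None and a2_scale.numel() != 1))
    # The assert in _do_quant requires the caller to pass exactly this value.
    assert implied  # so per_act_token_quant=True must be supplied here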

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 11 additions & 6 deletions
@@ -61,7 +61,6 @@ def moe_mmk(
     if use_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
-            # + (expert_id * stride_ase) ??
             a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm) #+ (expert_id * stride_ase)
             offs_bsn = offs_n // group_n
             b_scale_ptrs = (b_scale_ptr + offs_bsn * stride_bsn) + expert_id * stride_bse
@@ -376,12 +375,18 @@ def invoke_moe_batched_triton_kernel(
                 triton.cdiv(B.size(1), BLOCK_N))
 
     assert A_scale is None or A_scale.ndim == 3, f"{0 if A_scale is None else A_scale.shape}"
-    assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}"
+    assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else B_scale.shape}"
+    #assert B_scale is None or B_scale.ndim == 3, f"{0 if B_scale is None else (A.shape, B_scale.shape)}"
 
     if B_scale is not None:
-        stride_bse = B_scale.stride(0)
-        stride_bsn = B_scale.stride(1)
-        stride_bsk = B_scale.stride(2)
+        if B_scale.ndim == 1:
+            stride_bse = 1
+            stride_bsn = 0
+            stride_bsk = 0
+        else:
+            stride_bse = B_scale.stride(0)
+            stride_bsn = B_scale.stride(1)
+            stride_bsk = B_scale.stride(2)
     else:
         stride_bse = 0
         stride_bsk = 0
@@ -509,7 +514,7 @@ def prepare(
                              device=a1.device)
 
         if quant_dtype is not None:
-            if self.block_shape is not None:
+            if block_shape is not None:
                 _, block_k = block_shape
                 k_tiles = (hidden_dim + block_k - 1) // block_k
                 scale_shape = (num_local_experts, self.max_num_tokens, k_tiles)
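
Note: the new 1-D branch treats B_scale as a single per-expert scale, so only the expert dimension advances while the N/K tile strides broadcast to zero. A hypothetical helper expressing the same stride selection as the hunk above:

    from typing import Optional

    import torch

    def b_scale_strides(B_scale: Optional[torch.Tensor]) -> tuple[int, int, int]:
        # Returns (stride_bse, stride_bsn, stride_bsk) for the batched kernel.
        if B_scale is None:
            return 0, 0, 0
        if B_scale.ndim == 1:
            # One scale per expert: step once per expert, broadcast over tiles.
            return 1, 0, 0
        # 3-D block-wise scales: use the tensor's own strides.
        return B_scale.stride(0), B_scale.stride(1), B_scale.stride(2)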

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 4 additions & 10 deletions
@@ -1626,9 +1626,9 @@ def __init__(
         )
 
         super().__init__(
-            quant_dtype,
-            per_act_token_quant,
-            block_shape,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
         )
 
         self.use_fp8_w8a8 = use_fp8_w8a8
@@ -1762,7 +1762,7 @@ def apply(
         a2q_scale: Optional[torch.Tensor] = None
 
         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
-            intermediate_cache2, a2_scale, self.qtype,
+            intermediate_cache2, a2_scale, self.quant_dtype,
             self.per_act_token_quant, self.block_shape)
 
         invoke_fused_moe_kernel(qintermediate_cache2,
@@ -1795,12 +1795,6 @@ def modular_triton_fused_moe(
     per_act_token_quant: bool,
     block_shape: Optional[list[int]] = None,
 ) -> mk.FusedMoEModularKernel:
-    quant_dtype = get_config_quant_dtype(
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a8=use_int8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        use_int4_w4a16=use_int4_w4a16,
-    )
     return mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         TritonExperts(
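
Note: modular_triton_fused_moe no longer computes quant_dtype up front; the dtype selection presumably lives inside TritonExperts, which receives the same use_* flags. A rough sketch of the flag-to-dtype mapping that get_config_quant_dtype performs (an approximation; the exact vLLM logic may differ):

    from typing import Optional

    import torch

    def config_quant_dtype(use_fp8_w8a8: bool,
                           use_int8_w8a8: bool,
                           use_int8_w8a16: bool,
                           use_int4_w4a16: bool) -> Optional[torch.dtype]:
        # Activation quantization dtype implied by the kernel flags (sketch).
        if use_fp8_w8a8:
            return torch.float8_e4m3fn
        if use_int8_w8a8:
            return torch.int8
        # w8a16 / w4a16 variants keep activations unquantized.
        return None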
