
Commit 31b66d8

scale hacking
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent e5f2e2b commit 31b66d8

File tree: 7 files changed (+172, -67 lines)

tests/kernels/moe/test_batched_moe.py

Lines changed: 67 additions & 27 deletions
@@ -204,29 +204,69 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     config_block_shape = [16, 16, 32]  # 16 for k if not fp8
 
     #print(f"A {use_fp8_w8a8} {A_q.dtype} {B_q.dtype} {A_scale.shape} {B_scale.shape}")
-
-    invoke_moe_batched_triton_kernel(
-        A_q,
-        B_q,
-        test_output,
-        num_expert_tokens,
-        compute_tl_dtype,
-        # Quantization data
-        A_scale,
-        B_scale,
-        None,
-        # Quantization schemes
-        use_fp8_w8a8,
-        False,
-        False,
-        config={
-            "BLOCK_SIZE_M": config_block_shape[0],
-            "BLOCK_SIZE_N": config_block_shape[1],
-            "BLOCK_SIZE_K": config_block_shape[2],
-        },
-        per_act_token_quant=False,
-        block_shape=block_shape,
-    )
+    if False:
+        from vllm.model_executor.layers.fused_moe.batched_moe2 import fused_moe_kernel2
+        fused_moe_kernel2(
+            A_q,
+            B_q,
+            test_output,
+            A_scale,
+            B_scale,
+            num_expert_tokens,
+            N,
+            K,
+            max_tokens_per_expert,
+            max_tokens_per_expert,
+            A_q.stride(0),
+            A_q.stride(1),
+            A_q.stride(2),
+            B_q.stride(0),
+            B_q.stride(1),
+            B_q.stride(2),
+            test_output.stride(0),
+            test_output.stride(1),
+            A_scale.stride(0),
+            A_scale.stride(1),
+            A_scale.stride(2),
+            B_scale.stride(0),
+            B_scale.stride(1),
+            B_scale.stride(2),
+            block_shape[0] if block_shape is not None else 0,
+            block_shape[1] if block_shape is not None else 0,
+            config_block_shape[0],
+            config_block_shape[1],
+            config_block_shape[2],
+            1,
+            1,  # topk hack
+            compute_tl_dtype,
+            use_fp8_w8a8,
+            False,
+            False,
+            per_channel_quant=False,
+        )
+    else:
+        invoke_moe_batched_triton_kernel(
+            A_q,
+            B_q,
+            test_output,
+            num_expert_tokens,
+            compute_tl_dtype,
+            # Quantization data
+            A_scale,
+            B_scale,
+            None,
+            # Quantization schemes
+            use_fp8_w8a8,
+            False,
+            False,
+            config={
+                "BLOCK_SIZE_M": config_block_shape[0],
+                "BLOCK_SIZE_N": config_block_shape[1],
+                "BLOCK_SIZE_K": config_block_shape[2],
+            },
+            per_act_token_quant=False,
+            block_shape=block_shape,
+        )
 
     ref_output = ref_impl(
         A,
@@ -283,7 +323,7 @@ def per_block_cast_to_fp8(
     return x_scaled_sub, scales
 
 
-def make_test_weights(
+def _make_test_weights(
     e: int,
     n: int,
     k: int,
@@ -298,10 +338,10 @@ def make_test_weights(
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
-    w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
+    w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
 
-    w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
+    w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
     w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
 
     block_n, block_k = block_size[0], block_size[1]
@@ -330,7 +370,7 @@ def make_test_weights(
     return w1, w2, w1_s, w2_s, w1_bf16, w2_bf16
 
 
-def _make_test_weights(e, n, k, block_shape, dtype):
+def make_test_weights(e, n, k, block_shape, dtype):
     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
     w_dtype = torch.bfloat16 if use_fp8_w8a8 else dtype

vllm/distributed/device_communicators/all2all.py

Lines changed: 0 additions & 3 deletions
@@ -84,9 +84,6 @@ def __init__(self, cpu_group):
         assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
         super().__init__(cpu_group)
 
-        # Intranode doesn't work yet.
-        self.internode = True
-
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 49 additions & 18 deletions
@@ -29,8 +29,8 @@ def moe_mmk(
         stride_ak,
         stride_bk,
         stride_ase,
-        stride_ask,
         stride_asm,
+        stride_ask,
         stride_bse,
         stride_bsk,
         stride_bsn,
@@ -156,8 +156,8 @@ def expert_triton_kernel(
         stride_cm,
         stride_cn,
         stride_ase,
-        stride_ask,
         stride_asm,
+        stride_ask,
         stride_bse,
         stride_bsk,
         stride_bsn,
@@ -196,8 +196,8 @@ def expert_triton_kernel(
         stride_ak,
         stride_bk,
         stride_ase,
-        stride_ask,
         stride_asm,
+        stride_ask,
         stride_bse,
         stride_bsk,
         stride_bsn,
@@ -253,8 +253,8 @@ def batched_triton_kernel(
         stride_cm,
         stride_cn,
         stride_ase,
-        stride_ask,
         stride_asm,
+        stride_ask,
         stride_bse,
         stride_bsk,
         stride_bsn,
@@ -297,11 +297,11 @@ def batched_triton_kernel(
 
     if use_fp8_w8a8:
         # block-wise
-        if group_k > 0 and group_n > 0:
+        if (group_k > 0 and group_n > 0) or per_channel_quant:
             a_scale_ptr = a_scale_ptr + (expert_id * stride_ase) + cta_m_start * stride_asm
             #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)  # + cta_n_start * stride_bsn?
-        # channel-wise
-        elif per_channel_quant:
+        # channel-wise or tensor-wise
+        else:
             a_scale_ptr = a_scale_ptr + (expert_id * stride_ase)
             #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
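Note on the branch change above: block-wise and per-channel activation scales both vary along the token dimension, so they now share the per-row offset, while the remaining per-tensor case only offsets by expert. A minimal host-side sketch of that selection, with a hypothetical helper name (`a_scale_offset` is not part of the kernel):

```python
# Illustrative sketch only: restates the pointer-offset branch in
# batched_triton_kernel with a hypothetical helper. Offsets are in elements.
def a_scale_offset(expert_id: int, cta_m_start: int,
                   stride_ase: int, stride_asm: int,
                   group_k: int, group_n: int,
                   per_channel_quant: bool) -> int:
    if (group_k > 0 and group_n > 0) or per_channel_quant:
        # block-wise or channel-wise: scales vary per row, so offset by the
        # expert and by the first row this CTA handles.
        return expert_id * stride_ase + cta_m_start * stride_asm
    # tensor-wise: one scale per expert, offset by expert only.
    return expert_id * stride_ase
```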

@@ -325,8 +325,8 @@ def batched_triton_kernel(
         stride_cm,
         stride_cn,
         stride_ase,
-        stride_ask,
         stride_asm,
+        stride_ask,
         stride_bse,
         stride_bsk,
         stride_bsn,
@@ -373,6 +373,36 @@ def invoke_moe_batched_triton_kernel(
     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
             triton.cdiv(B.size(1), BLOCK_N))
 
+    assert A_scale is None or A_scale.ndim == 1 or A_scale.ndim == 3, f"{0 if A_scale is None else (A_scale.ndim, A_scale.shape)}"
+    assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, f"{0 if B_scale is None else (B_scale.ndim, B_scale.shape)}"
+
+    #print(f"SCALES {A_scale.shape}, {B_scale.shape}")
+
+    stride_bse = 0
+    stride_bsk = 0
+    stride_bsn = 0
+    if B_scale is not None:
+        if B_scale.ndim == 1:
+            stride_bsk = B_scale.stride(0)
+        else:
+            assert B_scale.ndim == 3
+            stride_bse = B_scale.stride(0)
+            stride_bsn = B_scale.stride(1)
+            stride_bsk = B_scale.stride(2)
+
+    stride_ase = 0
+    stride_asm = 0
+    stride_ask = 0
+    if A_scale is not None:
+        if A_scale.ndim == 1:
+            stride_ask = A_scale.stride(0)
+        else:
+            assert A_scale.ndim == 3
+            stride_ase = A_scale.stride(0)
+            stride_asm = A_scale.stride(1)
+            stride_ask = A_scale.stride(2)
+
+
     batched_triton_kernel[grid](
         A,
         B,
@@ -397,15 +427,12 @@ def invoke_moe_batched_triton_kernel(
         C.stride(0),
         C.stride(1),
         C.stride(2),
-
-        A_scale.stride(0) if A_scale is not None and A_scale.ndim >= 2 else 0, #E
-        A_scale.stride(2) if A_scale is not None and A_scale.ndim == 3 else 0, #K
-        A_scale.stride(1) if A_scale is not None and A_scale.ndim >= 2 else 0, #M
-
-        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, #E
-        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, #K
-        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, #N
-
+        stride_ase,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
         # Blockwise quantization data
         0 if block_shape is None else block_shape[0],
         0 if block_shape is None else block_shape[1],
@@ -537,7 +564,11 @@ def prepare(
 
             tokens_per_expert[idx] = rows
 
-        return b_a1, b_a1_scale, tokens_per_expert, None, None
+        #b_a1_scale.fill_(0.0001)
+        #print(f"A1Q_scale = {b_a1_scale.shape}\n{b_a1_scale}")
+        assert b_a1_scale is None or b_a1_scale.ndim == 3
+
+        return b_a1, b_a1_scale, tokens_per_expert
 
     def finalize(
         self,
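The stride setup added to `invoke_moe_batched_triton_kernel` accepts scale tensors that are either absent, 1-D (a single per-tensor scale), or 3-D (per the comments in the removed code: experts x rows x groups for A, experts x N x groups for B). A standalone sketch of that logic, factored into a hypothetical `scale_strides` helper that is not part of the patch:

```python
from typing import Optional

import torch


def scale_strides(scale: Optional[torch.Tensor]) -> tuple[int, int, int]:
    """Hypothetical restatement of the inlined stride setup: returns
    (expert, row, group) strides, with 0 for dimensions that are absent."""
    stride_e = stride_row = stride_grp = 0
    if scale is not None:
        if scale.ndim == 1:
            # per-tensor scale kept as a flat (often 1-element) vector
            stride_grp = scale.stride(0)
        else:
            assert scale.ndim == 3  # (experts, rows, groups)
            stride_e = scale.stride(0)
            stride_row = scale.stride(1)
            stride_grp = scale.stride(2)
    return stride_e, stride_row, stride_grp


# e.g. stride_ase, stride_asm, stride_ask = scale_strides(A_scale)
#      stride_bse, stride_bsn, stride_bsk = scale_strides(B_scale)
```

The argument reordering in the kernel signatures above (stride_asm now precedes stride_ask) puts the A-scale strides in that same (expert, row, group) order.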

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 35 additions & 16 deletions
@@ -97,27 +97,32 @@ def prepare(
                 "apply_router_weight_on_input is only implemented for topk=1")
             a1 = a1 * rank_topk_weights.to(a1.dtype)
 
-
-        repeat_cols = 4
-        repeat_rows = 1 if self.per_act_token_quant else a1.shape[0]
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1, (None if self.per_act_token_quant else a1_scale), self.quant_dtype,
             self.per_act_token_quant, self.block_shape)
 
+        # pplx requires 2-d scales even for scalars
         if a1q_scale is not None:
-            a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
+            if a1q_scale.dim() <= 1:
+                assert a1q_scale.numel() == 1
+                a1q_scale = a1q_scale.view(1, 1)
+
+            #print(f"ORIG {a1q_scale.shape}, {a1q_scale}")
+
+            orig_scale = a1q_scale
+            orig_a1q_scale_shape = a1q_scale.shape
 
-        # per_act_token_quant = a1_scale.numel() != 1 if a1_scale is not None else (
-        #     a2_scale.numel() != 1 if a2_scale is not None else False)
+            # pad out scales if needed
+            if a1q_scale.numel() == 1:
+                a1q_scale = a1q_scale.repeat(a1q.shape[1], 4)
 
-        # a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
-        #                                            self.quant_dtype,
-        #                                            per_act_token,
-        #                                            self.block_shape)
+            assert a1q_scale.shape[0] == a1q.shape[1]
 
-        if a1q_scale is not None and a1q_scale.dim() == 1:
-            assert a1q_scale.numel() == 1
-            a1q_scale = a1q_scale.view(1, 1)
+            #print(f"FINAL {a1q_scale.shape}, {a1q_scale}")
+
+
+        assert a1q_scale is None or a1q_scale.ndim == 2, \
+            f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
 
         # rem_experts need to be 0 for pplx to work properly.
         rem_experts = num_experts % self.world_size
@@ -147,7 +152,8 @@ def prepare(
         expert_x_scale_shape = (
             num_local_experts,
             expert_x.size(1),
-            (expert_x.size(2) + block_size - 1) // block_size,
+            #(expert_x.size(2) + block_size - 1) // block_size,
+            orig_a1q_scale_shape[-1],
        )
 
         #print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}")
@@ -176,9 +182,22 @@ def prepare(
         if expert_x_scale is not None:
             expert_x_scale = expert_x_scale[:, :, 0:1]
 
-        #print(f"ZZZZZZZZZZZZZZ")
+        #print(f"ZZZZZZZZZZZZZZ {expert_x_scale.shape}")
         if expert_x_scale is not None:
-            expert_x_scale = expert_x_scale[:, :, 0:1]
+            expert_x_scale = expert_x_scale[:, :, :orig_a1q_scale_shape[-1]]
+            from math import prod
+            if prod(orig_a1q_scale_shape) == 1:
+                expert_x_scale = expert_x_scale[:, :1, :1]
+                #print(f"EPT {expert_num_tokens.flatten()}")
+                #print(f"SCALARIZING!!! {expert_x_scale.shape}, {expert_x_scale.flatten()}")
+                idx = expert_num_tokens.flatten() != 0
+                assert torch.all(expert_x_scale.flatten()[idx] != 0)
+                #zidx = expert_num_tokens.flatten() == 0
+                #assert torch.all(expert_x_scale.flatten()[zidx] == 0)
+                assert expert_x_scale.ndim == 3
+            #expert_x_scale = orig_scale.view(1)
+
+        assert expert_x_scale.ndim == 1 or expert_x_scale.ndim == 3
 
         return expert_x, expert_x_scale, expert_num_tokens, None, None
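The prepare() changes above revolve around one constraint: pplx dispatch expects a 2-D activation scale even when quantization is per-tensor, so a scalar scale is reshaped and padded before dispatch and then sliced back down afterwards (`expert_x_scale[:, :, :orig_a1q_scale_shape[-1]]`, or `[:, :1, :1]` in the scalar case). A minimal sketch of the pre-dispatch step, assuming `max_tokens` stands in for `a1q.shape[1]` and the helper name is hypothetical:

```python
import torch


def pad_scale_for_pplx(a1q_scale: torch.Tensor, max_tokens: int) -> torch.Tensor:
    """Sketch of the scale normalization in prepare(): pplx requires 2-D
    scales even for scalars, so a 0-D/1-D per-tensor scale is viewed as
    (1, 1) and repeated across the token dimension (and 4 columns, matching
    the patch) before dispatch."""
    if a1q_scale.dim() <= 1:
        assert a1q_scale.numel() == 1
        a1q_scale = a1q_scale.view(1, 1)
    if a1q_scale.numel() == 1:
        a1q_scale = a1q_scale.repeat(max_tokens, 4)
    assert a1q_scale.ndim == 2
    return a1q_scale
```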

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 13 additions & 2 deletions
@@ -788,7 +788,12 @@ def select_gemm_impl(self, prepare_finalize, moe):
         use_batched_experts = max_num_tokens_per_rank is not None
 
         if use_batched_experts:
-            logger.debug("BatchedTritonExperts(fp8)")
+            logger.debug(
+                "BatchedTritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
+                self.__class__.__name__,
+                self.quant_config.weight_block_size,
+                False
+            )
             return BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
                 world_size=prepare_finalize.world_size,
@@ -799,10 +804,16 @@ def select_gemm_impl(self, prepare_finalize, moe):
                 allow_deep_gemm=self.allow_deep_gemm,
             )
         else:
-            logger.debug("TritonOrDeepGemmExperts(fp8)")
+            logger.debug(
+                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
+                self.__class__.__name__,
+                self.quant_config.weight_block_size,
+                False
+            )
             return TritonOrDeepGemmExperts(
                 use_fp8_w8a8=True,
                 block_shape=self.quant_config.weight_block_size,
+                per_act_token=False, #?
                 allow_deep_gemm=self.allow_deep_gemm,
             )
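For context, the dispatch these log messages describe hinges on a single predicate: batched experts are chosen whenever the prepare/finalize object reports a per-rank token budget. A condensed sketch (construction details omitted; `per_act_token` is hard-coded to False here, which the `#?` comment flags as provisional):

```python
from typing import Optional


def fp8_experts_choice(max_num_tokens_per_rank: Optional[int]) -> str:
    """Condensed sketch of the branch structure in select_gemm_impl; the
    real method constructs BatchedTritonOrDeepGemmExperts or
    TritonOrDeepGemmExperts with the arguments shown in the diff."""
    use_batched_experts = max_num_tokens_per_rank is not None
    if use_batched_experts:
        # fixed per-rank token capacity -> batched expert kernels
        return "BatchedTritonOrDeepGemmExperts"
    # no per-rank cap -> contiguous (non-batched) path
    return "TritonOrDeepGemmExperts"
```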
