
Commit d7bb199

prepare_finalize working

Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: eab92d3

File tree

4 files changed: +126 −49 lines

tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_pplx_moe.py
tests/kernels/quant_utils.py
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

tests/kernels/moe/test_batched_moe.py

Lines changed: 6 additions & 3 deletions
@@ -87,9 +87,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
-    if block_shape is not None and not use_fp8_w8a8:
+    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
         pytest.skip("Don't test blocking for non-quantized types.")
 
+    if per_act_token_quant and block_shape is not None:
+        pytest.skip("Illegal quantization.")
+
     if dtype.itemsize == 1:
         act_dtype = torch.bfloat16
         quant_dtype = dtype
@@ -182,7 +185,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
-    #torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
 
 
 @pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
@@ -213,7 +216,7 @@ def test_fused_moe_batched_experts(
     if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
        pytest.skip("Skip quantization test for non-quantized type")
 
-    if per_act_token_quant and block_shape is not None or topk > e:
+    if (per_act_token_quant and block_shape is not None) or topk > e:
        pytest.skip("Skip illegal quantization test")
 
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
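Note on the last hunk: the added parentheses do not change evaluation, since Python's "and" binds more tightly than "or"; they only make the intended grouping explicit. A minimal standalone sketch with illustrative values (not part of the commit):

# Python's "and" already binds tighter than "or", so both forms are equivalent.
per_act_token_quant, block_shape, topk, e = True, None, 6, 4
old_cond = per_act_token_quant and block_shape is not None or topk > e
new_cond = (per_act_token_quant and block_shape is not None) or topk > e
assert old_cond == new_cond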

tests/kernels/moe/test_pplx_moe.py

Lines changed: 71 additions & 27 deletions
@@ -20,6 +20,7 @@
 
 from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
 from tests.kernels.utils import torch_experts
+from tests.kernels.quant_utils import dequant
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_topk, override_config
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -39,8 +40,14 @@
     reason="Requires PPLX kernels",
 )
 
-PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
-                       (222, 2048, 1024)]
+PPLX_PREPARE_COMBOS = [
+    # (1, 128, 128),
+    (4, 128, 128),
+    (32, 1024, 512),
+    # (45, 512, 2048),
+    (64, 1024, 512),
+    (222, 2048, 1024),
+]
 
 PPLX_MOE_COMBOS = [
     (1, 128, 128),
@@ -194,18 +201,24 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
     return t[(r * chunk):(r + 1) * chunk]
 
 
+def dummy_work(a: torch.Tensor) -> torch.Tensor:
+    return a  # * 1.5
+
+
 def pplx_prepare_finalize(
     pgi: ProcessGroupInfo,
     dp_size: int,
     a: torch.Tensor,
-    a_scale: Optional[torch.Tensor],
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
     num_experts: int,
+    quant_dtype: Optional[torch.dtype],
+    block_shape: Optional[list[int]],
+    per_act_token_quant: bool,
     group_name: Optional[str],
 ) -> torch.Tensor:
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
-        PplxPrepareAndFinalize)
+        PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
 
     assert torch.cuda.current_device() == pgi.local_rank
 
@@ -214,7 +227,16 @@ def pplx_prepare_finalize(
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
-    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
+    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
+
+    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
+        max_num_tokens,
+        hidden_dim,
+        a.dtype,
+        quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+        block_shape=block_shape,
+    )
 
     args = dict(
         max_num_tokens=max_num_tokens,
@@ -224,8 +246,8 @@
         world_size=world_size,
         dp_size=dp_size,
         hidden_dim=hidden_dim,
-        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
-        hidden_dim_scale_bytes=0,
+        hidden_dim_bytes=hidden_dim_bytes,
+        hidden_dim_scale_bytes=scale_bytes,
     )
 
     if group_name is None:
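A back-of-envelope check of the per-token byte budget these buffers imply (illustrative arithmetic only; the exact layout and padding chosen by pplx_hidden_dim_scale_bytes is vllm's implementation detail):

# Assumed example: fp8 activations, hidden_dim = 1024, block_shape = [128, 128].
hidden_dim = 1024
fp8_itemsize = 1                          # bytes per quantized element
scale_itemsize = 4                        # float32 scale entries
blocks_per_row = -(-hidden_dim // 128)    # ceil(1024 / 128) = 8

payload_bytes = hidden_dim * fp8_itemsize           # 1024 bytes of quantized data
min_scale_bytes = blocks_per_row * scale_itemsize   # at least 32 bytes of scales
print(payload_bytes, min_scale_bytes)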
@@ -257,10 +279,17 @@
         num_experts,
         None,
         False,
-        FusedMoEQuantConfig(),
+        FusedMoEQuantConfig(
+            quant_dtype,
+            per_act_token_quant,
+            False,
+            block_shape,
+        ),
     )
 
-    b_a = b_a * 1.5
+    # Do some fake work
+    #print(f"INTER {b_a.shape} {b_a_scale.shape if b_a_scale is not None else None}")
+    b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
 
     out = torch.full(
         (max_num_tokens, hidden_dim),
@@ -290,10 +319,12 @@ def _pplx_prepare_finalize(
     pgi: ProcessGroupInfo,
     dp_size: int,
     a: torch.Tensor,
-    a_scale: Optional[torch.Tensor],
     score: torch.Tensor,
     topk: torch.Tensor,
     num_experts: int,
+    quant_dtype: Optional[torch.dtype],
+    block_shape: Optional[list[int]],
+    per_act_token_quant: bool,
     use_internode: bool,
 ):
     if use_internode:
@@ -307,24 +338,35 @@ def _pplx_prepare_finalize(
         cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
         group_name = cpu_group.group_name
 
-    device = pgi.device
+    #device = pgi.device
 
     topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-    k = a.shape[1]
+    m, k = a.shape
 
-    a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
+    a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)  #.to(device)
 
-    torch_output = (a_rep.view(-1, topk, k) * 1.5 *
-                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
-                        a.dtype)
+    if True:
+        torch_output = (a_rep.view(m, topk, k) *
+                        topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1)
+    else:
+        import vllm._custom_ops as ops
+        a_rep = a_rep.view(m, topk, k)
+        a_rep.mul_(topk_weight.view(m, topk, 1).to(a_rep.dtype))
+        torch_output = torch.empty_like(a)
+        ops.moe_sum(a_rep, torch_output)
 
-    pplx_output = pplx_prepare_finalize(pgi, dp_size, a, a_scale, topk_weight, topk_ids,
-                                        num_experts, group_name)
+    pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
+                                        num_experts, quant_dtype, block_shape,
+                                        per_act_token_quant, group_name)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
 
-    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
+    #torch.set_printoptions(profile="full")
+    #print(f"PPLX {pplx_output.shape}\n{pplx_output.shape}")
+    #print(f"TORCH {torch_output.shape}\n{torch_output.shape}")
+
+    torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
 
     if use_internode:
         nvshmem_finalize()
@@ -333,11 +375,10 @@ def _pplx_prepare_finalize(
 # TODO (bnell): this test point does not work for odd M due to how the test is
 # written, not due to limitations of the pplx kernels. The pplx_moe
 # test below is able to deal with odd M.
-# TODO (bnell) add fp8 tests
 @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])  # torch.float8_e4m3fn,
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @pytest.mark.parametrize("per_act_token_quant", [False])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
@@ -356,28 +397,31 @@ def test_pplx_prepare_finalize(
     if dtype == torch.float8_e4m3fn:
         use_fp8_w8a8 = True
         act_dtype = torch.bfloat16
+        quant_dtype = dtype
     else:
         use_fp8_w8a8 = False
         act_dtype = dtype
+        quant_dtype = None
 
     if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
         pytest.skip("Skip quantization test for non-quantized type")
 
     if per_act_token_quant and block_shape is not None:
-        pytest.skip("Skip illgal quantization combination")
+        pytest.skip("Skip illegal quantization combination")
 
     current_platform.seed_everything(7)
     m, n, k = mnk
     world_size, dp_size = world_dp_size
     device = "cuda"
 
+    #print(f"MNK = {mnk}")
+
     a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
     score = torch.randn((m, e), device=device, dtype=act_dtype)
 
-    a, a_scale = moe_kernel_quantize_input(a, None, dtype, False, block_shape)
-
-    parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, a_scale, score,
-                    topk, e, use_internode)
+    parallel_launch(world_size, _pplx_prepare_finalize, dp_size,
+                    a, score, topk, e, quant_dtype, block_shape,
+                    per_act_token_quant, use_internode)
 
 
 def pplx_moe(
@@ -661,7 +705,7 @@ def test_pplx_moe(
         pytest.skip("Skip quantization test for non-quantized type")
 
     if per_act_token_quant and block_shape is not None:
-        pytest.skip("Skip illgal quantization combination")
+        pytest.skip("Skip illegal quantization combination")
 
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
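For reference, the torch-side baseline that _pplx_prepare_finalize now checks against reduces to "replicate each token topk times, weight each copy by its routing weight, and sum". A self-contained sketch with made-up shapes (illustrative only, not the committed test code):

import torch

m, k, topk = 4, 8, 2
a = torch.randn(m, k)
topk_weight = torch.rand(m, topk)

a_rep = torch.repeat_interleave(a, topk, dim=0)      # (m * topk, k)
ref = (a_rep.view(m, topk, k) *
       topk_weight.view(m, topk, 1)).sum(dim=1)      # (m, k), mirrors the test
assert ref.shape == (m, k)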

tests/kernels/quant_utils.py

Lines changed: 16 additions & 7 deletions
@@ -236,12 +236,21 @@ def per_block_cast_to_fp8(
     return x_scaled_sub, scales
 
 
-def _dequant(t: torch.Tensor, scale: torch.Tensor, block_shape, per_act_token_quant) -> torch.Tensor:
-    f32 = torch.float32
-    if per_act_token_quant or block_shape is None:
-        return t.to(f32) * scale
+def dequant(
+    t: torch.Tensor,
+    scale: Optional[torch.Tensor],
+    block_shape: Optional[list[int]],
+    per_act_token_quant: bool,
+    out_dtype: Optional[torch.dtype] = torch.float32,
+) -> torch.Tensor:
+    if scale is not None:
+        f32 = torch.float32
+        if per_act_token_quant or block_shape is None:
+            return (t.to(f32) * scale).to(out_dtype)
+        else:
+            return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype)
     else:
-        return t.to(f32) * group_broadcast(scale, t.shape)
+        return t.to(out_dtype)
 
 
 def native_batched_masked_quant_matmul(
@@ -269,8 +278,8 @@ def native_batched_masked_quant_matmul(
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         elif A.dtype.itemsize == 1 and block_shape is None:
             assert A_scale is not None and B_scale is not None
-            A_dq = _dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
-            B_dq = _dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
+            A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
+            B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
             C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
         else:
             assert A_scale is None
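The per-token/per-tensor branch of the renamed dequant helper is just an upcast-and-multiply; only the block-wise branch needs group_broadcast to expand grouped scales to the tensor shape. A minimal sketch of the per-token path under assumed shapes (illustrative only, not the helper itself):

import torch

t_q = torch.randn(4, 8)              # stand-in for already-quantized fp8 data
scale = torch.rand(4, 1)             # one float32 scale per token (per-act-token)

deq = (t_q.to(torch.float32) * scale).to(torch.bfloat16)   # out_dtype analogue
assert deq.shape == t_q.shape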

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 33 additions & 12 deletions
@@ -127,6 +127,17 @@ def prepare(
             quant_config.quant_dtype, quant_config.per_act_token_quant,
             quant_config.block_shape)
 
+        if quant_config.quant_dtype is not None:
+            if quant_config.is_per_tensor:
+                assert a1q_scale.numel() == 1
+            elif quant_config.is_per_act_token:
+                assert a1q_scale.numel() == a1.numel()
+                assert a1q_scale.shape == a1.shape
+            else:
+                assert a1q_scale.numel() == a1.shape[0] * cdiv(a1.shape[1], quant_config.block_shape[1])
+                assert a1q_scale.shape == (a1.shape[0], cdiv(a1.shape[1], quant_config.block_shape[1]))
+                #a1q_scale = group_broadcast(scale, a1q.shape)
+
         if a1q_scale is not None:
             scalar_scales = a1q_scale.numel() == 1
 
@@ -138,15 +149,21 @@ def prepare(
             orig_a_scale_block_shape = a1q_scale.shape[-1]
 
             # pad out scales if needed. TODO (bnell): do for non-scalar scales?
-            if False and scalar_scales:
-                print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}")
-                a1q_scale = a1q_scale.repeat(a1q.shape[1],
-                                             4 * torch.float32.itemsize)
+            if False and (scalar_scales or quant_config.is_per_tensor):
+                #print(f"a1q_scale {a1q.shape}, {a1q_scale.shape}")
+                a1q_scale = a1q_scale.repeat(1, 4 * torch.float32.itemsize)
+            else:
+                #a1q_scale = torch.repeat_interleave(a1q_scale, round_up(a1q_scale.shape[1], 16), dim=1)
+                #a1q_scale = torch.nn.functional.pad(a1q_scale, pad=(0, 16-a1q_scale.shape[1]), mode='replicate')
+                pass
 
-            a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
+            if not quant_config.is_grouped:
+                a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
 
        #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}"
 
+        #print(f"FINAL SCALE SHAPE {a1q_scale.shape}")
+
        assert a1q_scale is None or a1q_scale.ndim == 2, \
            f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
 
@@ -173,16 +190,23 @@ def prepare(
         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
             float32_size = torch.float32.itemsize
-            block_size = (quant_config.block_shape[1] if quant_config.
-                          block_shape is not None else 1) * float32_size
+
+            if quant_config.is_per_act_token:
+                final_dim = expert_x.size(2)
+                assert final_dim % 4 == 0  #?
+            elif quant_config.is_per_tensor:
+                final_dim = 4
+            else:
+                num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
+                final_dim = round_up(num_blocks, 4)
 
             expert_x_scale_shape = (
                 num_local_experts,
                 expert_x.size(1),
-                cdiv(expert_x.size(2), block_size) if not scalar_scales else 1,
+                final_dim,
             )
 
-            print(f"EXPERT_X_SCALE {expert_x_scale_shape}")
+            #print(f"EXPERT_X_SCALE {expert_x_scale_shape}")
 
             expert_x_scale = torch.zeros(
                 expert_x_scale_shape,
@@ -207,9 +231,6 @@ def prepare(
        )
        #print(f"DISPATCH DONE {device}")
 
-        if expert_x_scale is not None:
-            expert_x_scale = expert_x_scale[:, :, 0:1]
-
        if expert_x_scale is not None:
            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
            assert expert_x_scale.ndim == 3