
Commit 604ab02

refactoring
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent db773b0 commit 604ab02

18 files changed: +674 −566 lines

tests/kernels/moe/test_batched_moe.py

Lines changed: 1 addition & 82 deletions
@@ -27,6 +27,7 @@
     torch_moe2,
     triton_moe,
     batched_moe,
+    make_test_weights,
 )
 
 NUM_EXPERTS = [8, 64]
@@ -302,27 +303,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
 
 
-# Move to utils
-def per_block_cast_to_fp8(
-        x: torch.Tensor,
-        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
-    from vllm.utils import cdiv
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros(
-        (cdiv(m, 128) * 128,
-         cdiv(n, block_size_n) * block_size_n),
-        dtype=x.dtype,
-        device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
-
-
 def _make_test_weights(
     e: int,
     n: int,
@@ -370,67 +350,6 @@ def _make_test_weights(
     return w1, w2, w1_s, w2_s, w1_bf16, w2_bf16
 
 
-def make_test_weights(e, n, k, block_shape, dtype):
-    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
-    w_dtype = torch.bfloat16 if use_fp8_w8a8 else dtype
-
-    w1_16 = torch.randn((e, 2 * n, k), device="cuda", dtype=w_dtype) / 15
-    w2_16 = torch.randn((e, k, n), device="cuda", dtype=w_dtype) / 15
-
-    if use_fp8_w8a8:
-        w1_l = [None] * e
-        w2_l = [None] * e
-        w1_s = [None] * e
-        w2_s = [None] * e
-        for idx in range(e):
-            if block_shape is not None:
-                w1_l[idx], w1_s[idx] = per_block_cast_to_fp8(
-                    w1_16[idx],
-                    block_shape[1],
-                )
-                w2_l[idx], w2_s[idx] = per_block_cast_to_fp8(
-                    w2_16[idx],
-                    block_shape[1],
-                )
-            else:
-                tmp, w1_s[idx] = per_token_group_quant_fp8(
-                    w1_16[idx].view(1, -1),
-                    w1_16[idx].numel()
-                )
-                w1_l[idx] = tmp.view(*w1_16[idx].shape)
-
-                tmp, w2_s[idx] = per_token_group_quant_fp8(
-                    w2_16[idx].view(1, -1),
-                    w2_16[idx].numel()
-                )
-                w2_l[idx] = tmp.view(*w2_16[idx].shape)
-
-        w1 = torch.stack(w1_l)
-        w2 = torch.stack(w2_l)
-        w1_s = torch.stack(w1_s)
-        w2_s = torch.stack(w2_s)
-        if w1_s.ndim == 2:
-            assert w1_s.shape[-1] == 1
-            w1_s = w1_s.view(-1, 1, 1)
-            w2_s = w2_s.view(-1, 1, 1)
-
-        if block_shape is not None:
-            block_n, block_k = block_shape
-            n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
-            k_tiles_w1 = (k + block_k - 1) // block_k
-            n_tiles_w2 = (k + block_n - 1) // block_n
-            k_tiles_w2 = (n + block_k - 1) // block_k
-            assert w1_s.shape == (e, n_tiles_w1, k_tiles_w1)
-            assert w2_s.shape == (e, n_tiles_w2, k_tiles_w2)
-    else:
-        w1 = w1_16
-        w2 = w2_16
-        w1_s = None
-        w2_s = None
-
-    return w1, w2, w1_s, w2_s, w1_16, w2_16
-
-
 @pytest.mark.parametrize("m", [32, 45, 64]) #[1, 33, 64, 222])
 @pytest.mark.parametrize("n", [128, 512, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024, 2048])
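
The two helpers deleted here now live in tests/kernels/moe/utils.py (see that file's diff below), so both MoE test files share one copy. A minimal usage sketch of the shared make_test_weights helper, not part of the commit, assuming a CUDA device and that it runs from the vLLM repo root so the tests package is importable:

import torch

from tests.kernels.moe.utils import make_test_weights

# Illustrative sizes; dimensions divisible by the 128x128 block shape.
e, n, k = 8, 512, 256
w1, w2, w1_s, w2_s, w1_16, w2_16 = make_test_weights(
    e, n, k, block_shape=[128, 128], dtype=torch.float8_e4m3fn)

assert w1.dtype == torch.float8_e4m3fn
assert w1.shape == (e, 2 * n, k) and w2.shape == (e, k, n)
assert w1_s.shape == (e, (2 * n) // 128, k // 128)  # one scale per weight block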

tests/kernels/moe/test_pplx_moe.py

Lines changed: 8 additions & 60 deletions
@@ -39,6 +39,7 @@
 from tests.kernels.moe.utils import (
     torch_moe2,
     naive_batched_moe,
+    make_test_weights,
 )
 
 
@@ -264,7 +265,7 @@ def pplx_prepare_finalize(
         chunk_topk_ids,
         num_experts,
         None,
-        False,
+        FusedMoEConfig(),
     )
 
     b_a = b_a * 1.5
@@ -583,7 +584,7 @@ def _pplx_moe(
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
-                                  use_fp8_w8a8, per_act_token_quant,
+                                  qtype, per_act_token_quant,
                                   block_shape)
         pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a,
                                w1, w2, topk_weight, topk_ids, w1_s, w2_s, qtype,
@@ -624,69 +625,16 @@ def test_pplx_moe(
     current_platform.seed_everything(7)
     m, n, k = mnk
     world_size, dp_size = world_dp_size
-    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
-    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
-    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
-    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
     if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
         pytest.skip("Skip quantization test for non-quantized type")
 
-    # TODO (bnell): scale setup for different quant strategies?
-    if use_fp8_w8a8:
-        quant_type = torch.float8_e4m3fn
-
-        #finfo = torch.finfo(dtype)
-        #fp8_min = finfo.min
-        #fp8_max = finfo.max
-        #w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
-        #w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
-        # block_n, block_k = block_shape[0], block_shape[1]
-        # n_tiles_w1 = (2 * n + block_n - 1) // block_n
-        # n_tiles_w2 = (k + block_n - 1) // block_n
-        # k_tiles_w1 = (k + block_k - 1) // block_k
-        # k_tiles_w2 = (n + block_k - 1) // block_k
-        # factor_for_scale = 1e-2
-        # w1_s = torch.rand(
-        #     (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
-        #     device="cuda") * factor_for_scale
-        # w2_s = torch.rand(
-        #     (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
-        #     device="cuda") * factor_for_scale
-        w1_l = [None] * e
-        w2_l = [None] * e
-        w1_s = [None] * e
-        w2_s = [None] * e
-        for idx in range(e):
-            w1_l[idx], w1_s[idx] = moe_kernel_quantize_input(
-                w1[idx],
-                None,
-                quant_type,
-                per_act_token_quant,
-                block_shape
-            )
-            w2_l[idx], w2_s[idx] = moe_kernel_quantize_input(
-                w2[idx],
-                None,
-                quant_type,
-                per_act_token_quant,
-                block_shape
-            )
-        w1 = torch.stack(w1_l)
-        w2 = torch.stack(w2_l)
-        w1_s = torch.stack(w1_s)
-        w2_s = torch.stack(w2_s)
-        if w1_s.ndim == 2:
-            assert w1_s.shape[-1] == 1
-            w1_s = w1_s.view(-1, 1, 1)
-            w2_s = w2_s.view(-1, 1, 1)
-    else:
-        quant_type = None
-        w1_s = None
-        w2_s = None
+    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
+    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+
+    w1, w2, w1_s, w2_s, w1_16, w2_16 = make_test_weights(e, n, k, block_shape, dtype)
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
-                    w1_s, w2_s, quant_type, per_act_token_quant, block_shape,
+                    w1_s, w2_s, dtype, per_act_token_quant, block_shape,
                     use_internode)
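
With this change test_pplx_moe builds its weights through the same shared helper as test_batched_moe. A small sketch of the unquantized path that the single call site now also covers, not part of the commit, assuming a CUDA device and the vLLM repo root on the path: for dtype=torch.bfloat16 the weights come back unchanged and the scales are None.

import torch

from tests.kernels.moe.utils import make_test_weights

w1, w2, w1_s, w2_s, w1_16, w2_16 = make_test_weights(
    e=8, n=512, k=256, block_shape=None, dtype=torch.bfloat16)

assert w1_s is None and w2_s is None          # no quantization, no scales
assert w1.dtype == torch.bfloat16 and w1 is w1_16  # bf16 weights returned as-is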

tests/kernels/moe/utils.py

Lines changed: 103 additions & 5 deletions
@@ -11,6 +11,10 @@
     BatchedPrepareAndFinalize,
     BatchedTritonExperts,
     NaiveBatchedExperts)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    w8a8_block_fp8_matmul,
+    per_token_group_quant_fp8)
+
 from vllm.utils import round_up
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
@@ -112,6 +116,8 @@ def torch_moe2(
         block_shape
     )
 
+    print(f"XXX {quant_type} {block_shape} {a.shape} {a_scale}")
+
     out = torch.zeros(M * topk,
                       w2.shape[1],
                       dtype=torch.bfloat16,
@@ -129,8 +135,14 @@ def torch_moe2(
             tmp2 = SiluAndMul()(tmp1)
             out[mask] = tmp2 @ w2[i].transpose(0, 1)
         elif block_shape is not None:
-            tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
-                                            w1_scale[i], block_shape, out.dtype)
+            tmp1 = native_w8a8_block_matmul(
+                a[mask],
+                w1[i],
+                a_scale[mask],
+                w1_scale[i],
+                block_shape,
+                out.dtype
+            )
 
             #print(f"TORCH INTER[{i}] {tmp1.shape}\n{tmp1}")
             #inters[i, :tmp1.shape[0]] = tmp1
@@ -144,9 +156,14 @@ def torch_moe2(
                per_act_token_quant,
                block_shape)
 
-            out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
-                                                 w2_scale[i], block_shape,
-                                                 out.dtype)
+            out[mask] = native_w8a8_block_matmul(
+                tmp2,
+                w2[i],
+                b_scale,
+                w2_scale[i],
+                block_shape,
+                out.dtype
+            )
         else:
             # XXXX need scales here
             compute_type = torch.bfloat16
@@ -237,3 +254,84 @@ def naive_batched_moe(
 
     return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
 
+
+# Move to utils
+def per_block_cast_to_fp8(
+        x: torch.Tensor,
+        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    from vllm.utils import cdiv
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (cdiv(m, 128) * 128,
+         cdiv(n, block_size_n) * block_size_n),
+        dtype=x.dtype,
+        device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
+    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+    return x_scaled_sub, scales
+
+
+def make_test_weights(e, n, k, block_shape, dtype):
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
+    w_dtype = torch.bfloat16 if use_fp8_w8a8 else dtype
+
+    w1_16 = torch.randn((e, 2 * n, k), device="cuda", dtype=w_dtype) / 15
+    w2_16 = torch.randn((e, k, n), device="cuda", dtype=w_dtype) / 15
+
+    if use_fp8_w8a8:
+        w1_l = [None] * e
+        w2_l = [None] * e
+        w1_s = [None] * e
+        w2_s = [None] * e
+        for idx in range(e):
+            if block_shape is not None:
+                w1_l[idx], w1_s[idx] = per_block_cast_to_fp8(
+                    w1_16[idx],
+                    block_shape[1],
+                )
+                w2_l[idx], w2_s[idx] = per_block_cast_to_fp8(
+                    w2_16[idx],
+                    block_shape[1],
+                )
+            else:
+                tmp, w1_s[idx] = per_token_group_quant_fp8(
+                    w1_16[idx].view(1, -1),
+                    w1_16[idx].numel()
+                )
+                w1_l[idx] = tmp.view(*w1_16[idx].shape)
+
+                tmp, w2_s[idx] = per_token_group_quant_fp8(
+                    w2_16[idx].view(1, -1),
+                    w2_16[idx].numel()
+                )
+                w2_l[idx] = tmp.view(*w2_16[idx].shape)
+
+        w1 = torch.stack(w1_l)
+        w2 = torch.stack(w2_l)
+        w1_s = torch.stack(w1_s)
+        w2_s = torch.stack(w2_s)
+        if w1_s.ndim == 2:
+            assert w1_s.shape[-1] == 1
+            w1_s = w1_s.view(-1, 1, 1)
+            w2_s = w2_s.view(-1, 1, 1)
+
+        if block_shape is not None:
+            block_n, block_k = block_shape
+            n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
+            k_tiles_w1 = (k + block_k - 1) // block_k
+            n_tiles_w2 = (k + block_n - 1) // block_n
+            k_tiles_w2 = (n + block_k - 1) // block_k
+            assert w1_s.shape == (e, n_tiles_w1, k_tiles_w1)
+            assert w2_s.shape == (e, n_tiles_w2, k_tiles_w2)
+    else:
+        w1 = w1_16
+        w2 = w2_16
+        w1_s = None
+        w2_s = None
+
+    return w1, w2, w1_s, w2_s, w1_16, w2_16
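
A shape-level sketch of the per_block_cast_to_fp8 helper added above (an illustration, not part of the commit; assumes a CUDA device and the vLLM repo root on the path): with the default block_size_n=128 it pads the input to 128x128 tiles, emits one float32 scale per tile, and returns values in torch.float8_e4m3fn.

import torch

from tests.kernels.moe.utils import per_block_cast_to_fp8

x = torch.randn(200, 384, device="cuda", dtype=torch.bfloat16)
x_fp8, scales = per_block_cast_to_fp8(x)

assert x_fp8.dtype == torch.float8_e4m3fn and x_fp8.shape == x.shape
# One scale per 128x128 tile of the padded input: ceil(200/128) x ceil(384/128).
assert scales.shape == (2, 3)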

tests/kernels/quant_utils.py

Lines changed: 2 additions & 6 deletions
@@ -100,7 +100,8 @@ def native_w8a8_block_matmul(
     As: torch.Tensor,
     Bs: torch.Tensor,
     block_size: list[int],
-    output_dtype: torch.dtype
+    output_dtype: torch.dtype,
+    compute_type: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """This function performs matrix multiplication with block-wise
     quantization using native torch.
@@ -111,11 +112,6 @@
     `Bs` (float32).
     The output is returned in the specified `output_dtype`.
     """
-    if A.dtype.itemsize <= 2:
-        compute_type = torch.bfloat16
-    else:
-        compute_type = torch.float32
-
     A = A.to(compute_type)
     B = B.to(compute_type)
     assert A.shape[-1] == B.shape[-1]
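
The reference matmul no longer infers its accumulation dtype from A's itemsize; it now computes in float32 unless a compute_type is passed. A hedged sketch of the new call, not from the commit, assuming the existing positional order (A, B, As, Bs, ...), a CUDA device, and scales of 1.0 so the quantized inputs are trivially valid:

import torch

from tests.kernels.quant_utils import native_w8a8_block_matmul

M, N, K, blk = 16, 256, 256, 128
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.ones(M, K // blk, device="cuda", dtype=torch.float32)         # per (token, K-block)
Bs = torch.ones(N // blk, K // blk, device="cuda", dtype=torch.float32)  # per (N-block, K-block)

# Default now accumulates in float32 ...
out = native_w8a8_block_matmul(A, B, As, Bs, [blk, blk], torch.bfloat16)
# ... and the old bfloat16 accumulation can still be requested explicitly.
out_bf16 = native_w8a8_block_matmul(A, B, As, Bs, [blk, blk], torch.bfloat16,
                                    compute_type=torch.bfloat16)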

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 8 additions & 3 deletions
@@ -8,6 +8,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig)
+
 
 logger = init_logger(__name__)
 
@@ -35,9 +38,11 @@ def __init__(
 
         assert self.block_shape == [self.DEEPGEMM_BLOCK_SHAPE, self.DEEPGEMM_BLOCK_SHAPE]
         super().__init__(
-            quant_dtype=torch.float8_e4m3fn,
-            per_act_token_quant=False,
-            block_shape=block_shape,
+            FusedMoEQuantConfig(
+                quant_dtype=torch.float8_e4m3fn,
+                per_act_token_quant=False,
+                block_shape=block_shape,
+            )
         )
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
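
The quantization settings previously passed as loose keyword arguments are now bundled into a FusedMoEQuantConfig and handed to the experts base class as a single object. A minimal construction sketch using only the fields visible in this diff (any other fields are assumed to keep their defaults):

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

# Same settings the BatchedDeepGemmExperts constructor now wraps up.
quant_config = FusedMoEQuantConfig(
    quant_dtype=torch.float8_e4m3fn,   # fp8 weights and activations
    per_act_token_quant=False,         # no per-token activation quantization
    block_shape=[128, 128],            # DeepGEMM 128x128 block quantization
)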
