Commit 468d166

cleanup quantization

Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 909f234 commit 468d166
5 files changed: +96 additions, -78 deletions
tests/kernels/moe/test_batched_moe.py

Lines changed: 9 additions & 6 deletions

@@ -270,21 +270,23 @@ def batched_moe(
     topk_ids: torch.Tensor,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
+    qtype: Optional[torch.dtype] = None,
     block_shape: Optional[list[int]] = None,
+    per_act_token: bool = False,
 ) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)  # ?
     fused_experts = FusedMoEModularKernel(
         BatchedPrepareAndFinalize(max_num_tokens,
                                   world_size=1,
                                   dp_size=1,
                                   rank=0,
-                                  use_fp8_w8a8=use_fp8_w8a8,
-                                  block_shape=block_shape),
+                                  qtype=qtype,
+                                  block_shape=block_shape,
+                                  per_act_token=False),
         BatchedTritonExperts(max_num_tokens=max_num_tokens,
                              dp_size=1,
                              world_size=1,
-                             use_fp8_w8a8=use_fp8_w8a8,
+                             use_fp8_w8a8=qtype == torch.float8_e4m3fn,
                              block_shape=block_shape))
 
     return fused_experts(a,
@@ -360,7 +362,7 @@ def torch_moe2(
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -378,6 +380,7 @@ def test_fused_moe_batched_experts(
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
 
     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    qtype = dtype if dtype == torch.torch.float8_e4m3fn else None
 
     if use_fp8_w8a8:
         block_n, block_k = block_shape[0], block_shape[1]
@@ -409,7 +412,7 @@ def test_fused_moe_batched_experts(
     baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
                                  w2_s, use_fp8_w8a8, block_shape)
     batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                 w2_s, use_fp8_w8a8, block_shape)
+                                 w2_s, qtype, block_shape)
 
     torch.testing.assert_close(baseline_output,
                                batched_output,
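Note on the test change above: the boolean use_fp8_w8a8 switch becomes an optional quantization dtype (qtype). A minimal standalone sketch of the mapping the updated test performs (illustrative only, not part of the commit):

import torch

# Mirrors the new test logic: the parametrized dtype doubles as the
# quantization dtype when it is fp8, and maps to None (unquantized) otherwise.
for dtype in (torch.float8_e4m3fn, torch.bfloat16):
    qtype = dtype if dtype == torch.float8_e4m3fn else None
    use_fp8_w8a8 = qtype is not None
    print(dtype, "->", qtype, "fp8 path:", use_fp8_w8a8)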

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 75 additions & 43 deletions

@@ -9,9 +9,9 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache,
+    moe_kernel_quantize_input)
 
 
 @triton.jit
@@ -47,6 +47,7 @@ def moe_mmk(
         compute_type: tl.constexpr,
         use_w8a8: tl.constexpr,
         use_w8a16: tl.constexpr):
+
     offs_k = tl.arange(0, BLOCK_K)
 
     if use_w8a16:
@@ -325,6 +326,7 @@ def invoke_moe_batched_triton_kernel(
         use_int4_w4a16: bool,
         config: dict[str, int],
         block_shape: Optional[list[int]] = None):
+
     assert not use_int4_w4a16
     max_num_tokens = A.size(1)
     K = A.size(2)
@@ -393,15 +395,17 @@ def __init__(self,
                  world_size: int,
                  dp_size: int,
                  rank: int,
-                 use_fp8_w8a8: bool = False,
+                 qtype: Optional[torch.dtype] = None,
+                 per_act_token: bool = False,
                  block_shape: Optional[list[int]] = None):
         super().__init__()
         self.world_size = world_size
         self.dp_size = dp_size
         self.rank = rank
         self.max_num_tokens = max_num_tokens
-        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.per_act_token = per_act_token
         self.block_shape = block_shape
+        self.qtype = qtype
 
     def prepare(
         self,
@@ -445,10 +449,10 @@ def prepare(
 
         b_a1 = torch.zeros(
             (num_local_experts, self.max_num_tokens, hidden_dim),
-            dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else a1.dtype,
+            dtype=self.qtype if self.qtype is not None else a1.dtype,
            device=a1.device)
 
-        if self.use_fp8_w8a8:
+        if self.qtype is not None:
            k_tiles = (hidden_dim + block_k - 1) // block_k
            b_a1_scale = torch.zeros(
                (num_local_experts, self.max_num_tokens, k_tiles),
@@ -465,10 +469,20 @@ def prepare(
                rows = torch.count_nonzero(topks.flatten())
                rhs = a1[:topks.numel()][topks]
                idx = expert_id - first_expert
-                if self.use_fp8_w8a8:
-                    # TODO: use _fp8_quantize
-                    b_a1[idx, :rows, :], b_a1_scale[
-                        idx, :rows] = per_token_group_quant_fp8(rhs, block_k)
+                if self.qtype is not None:
+                    if a1_scale is not None:
+                        rhs_a1_scale = a1_scale[:topks.numel()][topks]
+                    else:
+                        rhs_a1_scale = None
+                    b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = (
+                        moe_kernel_quantize_input(
+                            rhs,
+                            rhs_a1_scale,
+                            self.qtype,
+                            self.per_act_token,
+                            self.block_shape,
+                        )
+                    )
                else:
                    b_a1[idx, :rows, :] = rhs
 
@@ -524,7 +538,6 @@ def __init__(
         block_m: Optional[int] = None,
     ):
         super().__init__()
-        #assert block_shape is None
         assert block_m is None
         assert not use_int8_w8a8, "NYI"
         assert not use_int8_w8a16, "NYI"
@@ -615,6 +628,42 @@ def apply(
         return out
 
 
+def batched_moe_kernel_quantize_input(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    num_tokens: int,
+    E: int,
+    N: int,
+    expert_num_tokens: torch.Tensor,
+    qtype: Optional[torch.dtype],
+    per_channel_quant: bool,
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if qtype is not None:
+        assert block_shape is not None
+        A_q = torch.empty_like(A, dtype=qtype)
+        block_n, block_k = block_shape
+        n_tiles = ((N // 2) + block_n - 1) // block_n
+        scale_shape = (E, num_tokens, n_tiles)
+        A_q_scale = torch.empty(scale_shape,
+                                dtype=torch.float32,
+                                device=A.device)
+        for e in range(E):
+            num_tokens = expert_num_tokens[e]
+            if num_tokens > 0:
+                A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
+                    A[e, :num_tokens],
+                    A_scale[e, :num_tokens] if A_scale else None,
+                    qtype,
+                    per_channel_quant,
+                    [block_k, block_n])
+                A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
+
+        return A_q, A_q_scale
+    else:
+        return A, A_scale
+
+
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     """
     A Triton based MoE expert class that operates on expert batched format,
@@ -630,6 +679,7 @@ def __init__(
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
         block_shape: Optional[list[int]] = None,
+        per_act_token: bool = False,
         world_size: int = 1,
         dp_size: int = 1,
     ):
@@ -644,6 +694,8 @@ def __init__(
         assert not use_int4_w4a16, "NYI"
         self.world_size = world_size
         self.dp_size = dp_size
+        self.per_act_token = per_act_token
+        self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
 
     def workspace_shapes(
         self,
@@ -731,7 +783,6 @@ def apply(
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")
 
-        #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
         # We can reuse the memory between these because by the time we need
         # cache3, we're done with cache1
         intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
@@ -761,36 +812,17 @@ def apply(
         self.activation(activation, intermediate_cache2.view(-1, N // 2),
                         intermediate_cache1.view(-1, N))
 
-        #qintermediate_cache2 = intermediate_cache2
-
-        # TODO (varun) : support w8a8
-        #assert not self.use_fp8_w8a8
-        if self.use_fp8_w8a8:
-            per_act_token = False
-            # TODO: reuse?
-            qintermediate_cache2 = torch.empty_like(intermediate_cache2,
-                                                    dtype=torch.float8_e4m3fn)
-            block_n = self.block_shape[0]
-            n_tiles = ((N // 2) + block_n - 1) // block_n
-            scale_shape = (E, num_tokens, n_tiles)
-            a2q_scale = torch.empty(scale_shape,
-                                    dtype=torch.float32,
-                                    device=hidden_states.device)
-            for e in range(E):
-                num_tokens = expert_num_tokens[e]
-                if num_tokens > 0:
-                    #qintermediate_cache2[e], tmp_scale = _fp8_quantize(
-                    #    intermediate_cache2[e],
-                    #    a2_scale[e] if a2_scale is not None else None,
-                    #    per_act_token, self.block_shape)
-                    qintermediate_cache2[
-                        e, :
-                        num_tokens, :], tmp_scale = per_token_group_quant_fp8(
-                            intermediate_cache2[e, :num_tokens], block_n)
-                    a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale
-        else:
-            qintermediate_cache2 = intermediate_cache2
-            a2q_scale = a2_scale
+        qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
+            intermediate_cache2,
+            a2_scale,
+            num_tokens,
+            E,
+            N,
+            expert_num_tokens,
+            self.qtype,
+            self.per_act_token,
+            self.block_shape
+        )
 
         invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
                                          B=w2,
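Note on fused_batched_moe.py: the hand-rolled per-expert fp8 loop in BatchedTritonExperts.apply is folded into the new batched_moe_kernel_quantize_input helper, which quantizes only the valid rows of each expert's batch. A rough standalone sketch of that ragged, per-expert pattern, with a toy absmax int8 quantizer standing in for moe_kernel_quantize_input (shapes and names below are illustrative only):

import torch

def toy_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-token absmax quantization to int8; a stand-in for the real
    # block-quantizing moe_kernel_quantize_input.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    return (x / scale).round().to(torch.int8), scale

E, max_tokens, hidden = 4, 8, 16
A = torch.randn(E, max_tokens, hidden)
expert_num_tokens = torch.tensor([8, 3, 0, 5])  # ragged: valid rows per expert

A_q = torch.zeros_like(A, dtype=torch.int8)
A_scale = torch.zeros(E, max_tokens, 1)
for e in range(E):
    n = int(expert_num_tokens[e])
    if n > 0:  # experts with no routed tokens are skipped entirely
        A_q[e, :n], A_scale[e, :n] = toy_quantize(A[e, :n])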

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 5 deletions

@@ -1520,11 +1520,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
-        use_fp8_w8a8: bool,
-        use_int8_w8a8: bool,
-        use_int8_w8a16: bool,
-        use_int4_w4a16: bool,
-        per_channel_quant: bool,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        per_channel_quant: bool = False,
         block_shape: Optional[list[int]] = None,
         block_m: Optional[int] = None,
     ):
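Note on fused_moe.py: giving every quantization flag a default lets callers construct an unquantized TritonExperts with no arguments, which is what the layer.py hunk below relies on. A small stand-in class showing the effect of the defaulted signature (hypothetical, not vllm's class):

from typing import Optional

class TritonExpertsLike:
    # Stand-in with the same defaulted signature as the hunk above.
    def __init__(self,
                 use_fp8_w8a8: bool = False,
                 use_int8_w8a8: bool = False,
                 use_int8_w8a16: bool = False,
                 use_int4_w4a16: bool = False,
                 per_channel_quant: bool = False,
                 block_shape: Optional[list[int]] = None,
                 block_m: Optional[int] = None):
        self.quantized = (use_fp8_w8a8 or use_int8_w8a8
                          or use_int8_w8a16 or use_int4_w4a16)

# Bare construction now means "no quantization", matching the layer.py cleanup.
assert TritonExpertsLike().quantized is False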

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 20 deletions

@@ -192,7 +192,7 @@ class MoEConfig:
     num_local_experts: int
     moe_parallel_config: FusedMoEParallelConfig
 
-    in_dtype: torch.dtype  # The activation type.
+    in_dtype: torch.dtype  # The post quantization activation type.
 
     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
@@ -489,22 +489,10 @@ def set_prepare_finalize(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=world_size,
                 dp_size=dp_size,
-                use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
-            )
+            experts = TritonExperts()
 
         self.fused_experts = FusedMoEModularKernel(
             prepare_finalize,
@@ -827,8 +815,7 @@ def __init__(
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
-        logger.debug(f"PARAM DTYPE = {params_dtype}")
-        #assert params_dtype.itemsize == 1
+        logger.debug("Model dtype = %s", vllm_config.model_config.dtype)
 
         moe = MoEConfig(
             max_num_tokens=MOE_DP_CHUNK_SIZE,
@@ -838,7 +825,6 @@ def __init__(
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=moe.in_dtype,
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config
@@ -877,15 +863,14 @@ def __init__(
         self.batched_hidden_states: Optional[torch.Tensor] = None
         self.batched_router_logits: Optional[torch.Tensor] = None
         if self.moe_parallel_config.use_pplx_kernels:
-            act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
                 (MOE_DP_CHUNK_SIZE, self.hidden_size),
-                dtype=act_dtype,
+                dtype=moe.in_dtype,
                 device=torch.cuda.current_device())
 
             self.batched_router_logits = torch.zeros(
                 (MOE_DP_CHUNK_SIZE, self.global_num_experts),
-                dtype=act_dtype,
+                dtype=moe.in_dtype,
                 device=torch.cuda.current_device())
 
     @property
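Note on layer.py: MoEConfig.in_dtype is now documented as the post-quantization activation type and is used directly to size the pplx chunk buffers, instead of re-reading vllm_config.model_config.dtype. A minimal sketch of that single-source buffer sizing (made-up sizes, CPU tensors for illustration):

import torch

# Illustrative only: buffer allocation driven by one dtype field, as in the
# hunk above. MOE_DP_CHUNK_SIZE, hidden_size and global_num_experts are made up.
MOE_DP_CHUNK_SIZE, hidden_size, global_num_experts = 256, 1024, 8
in_dtype = torch.bfloat16  # plays the role of MoEConfig.in_dtype

batched_hidden_states = torch.zeros(
    (MOE_DP_CHUNK_SIZE, hidden_size), dtype=in_dtype)
batched_router_logits = torch.zeros(
    (MOE_DP_CHUNK_SIZE, global_num_experts), dtype=in_dtype)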

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 4 deletions

@@ -782,11 +782,9 @@ def select_gemm_impl(self, prepare_finalize):
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=world_size,
                 dp_size=dp_size,
-                use_fp8_w8a8=True,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
+                qtype=torch.float8_e4m3fn,
                 block_shape=self.quant_config.weight_block_size,
+                per_act_token=False, #?
             )
         else:
             logger.debug("TritonOrDeepGemmExperts(fp8)")
