
Commit 53fa457

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)
[Misc] Add unit tests for MoE ModularKernel combinations + Profiling utility (#20449)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 6fb1624 commit 53fa457
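
For reference, the CLI helpers added below can be exercised roughly as follows. This is an illustrative sketch, not part of the commit: the new file's module name is not visible in this excerpt, and the kernel type names are assumptions; the valid --pf-type / --experts-type choices are whatever class names mk_objects.py actually exports.

    # Illustrative only: "MoEPrepareAndFinalizeNoEP" and "TritonExperts" are
    # assumed class names; substitute names actually listed in mk_objects.py.
    parser = make_config_arg_parser(description="profile MoE kernels")
    args = parser.parse_args([
        "--world-size", "1",
        "--pf-type", "MoEPrepareAndFinalizeNoEP",
        "--experts-type", "TritonExperts",
        "-m", "64", "128",
        "--quant-dtype", "torch.float8_e4m3fn",
        "--block-shape", "128", "128",
    ])
    config = make_config(args)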

File tree

15 files changed: +1727 −22 lines changed

tests/kernels/moe/modular_kernel_tools/__init__.py

Whitespace-only changes.
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

from .common import Config
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
                         MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)

def make_config_arg_parser(description: str):

    def to_pf_class_type(s: str) -> type[mk.FusedMoEPrepareAndFinalize]:
        for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
            if pf.__name__ == s:
                return pf
        raise ValueError(
            f"Cannot find a PrepareFinalize type that matches {s}")

    def to_experts_class_type(
            s: str) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
        for fe in MK_FUSED_EXPERT_TYPES:
            if fe.__name__ == s:
                return fe
        raise ValueError(f"Cannot find a FusedExperts type that matches {s}")

    def to_quant_torch_dtype(s: str) -> torch.dtype:
        if s == "torch.float8_e4m3fn":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported quant type {s}")

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "--world-size",
        type=int,
        default=2,
        help="Number of ranks that participate in all2all",
    )
    parser.add_argument(
        "--pf-type",
        type=to_pf_class_type,
        required=True,
        help=("Choose a PrepareFinalize type: "
              f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
    )
    parser.add_argument(
        "--experts-type",
        type=to_experts_class_type,
        required=True,
        help=("Choose a FusedExperts type: "
              f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
    )
    parser.add_argument(
        "-m",
        nargs="+",
        type=int,
        default=[64],
        help="Number of tokens per rank",
    )
    parser.add_argument(
        "-k",
        type=int,
        default=7168,
        help="Hidden size",
    )
    parser.add_argument(
        "-n",
        type=int,
        default=1024,
        help="N dimension of the first fused-moe matmul",
    )
    parser.add_argument("--num-experts",
                        type=int,
                        default=32,
                        help="Global number of experts")
    parser.add_argument("--topk",
                        nargs="+",
                        type=int,
                        default=[4, 1],
                        help="Number of top-k experts")
    parser.add_argument(
        "--fused-moe-chunk-size",
        nargs="+",
        type=int,
        help="Fused-MoE chunk size used for the non-batched "
        "fused experts implementation.")

    # Quant args
    parser.add_argument("--quant-dtype",
                        type=to_quant_torch_dtype,
                        help="Quantization datatype")
    parser.add_argument("--per-token-quantized-activations",
                        action="store_true",
                        help="Quantize the input activations per-token")
    parser.add_argument("--per-channel-quantized-weights",
                        action="store_true",
                        help="Quantize the weights per-channel")
    parser.add_argument("--block-shape",
                        nargs="+",
                        type=int,
                        help="Quantization block shape")

    # Torch trace profile generation args
    parser.add_argument("--torch-trace-dir-path",
                        type=str,
                        default=None,
                        help="Directory to dump the torch trace of a "
                        "single execution")

    return parser

def _validate_args(args: argparse.Namespace):

    if args.quant_dtype is not None:
        assert args.quant_dtype == torch.float8_e4m3fn
        if args.block_shape is not None:
            assert len(args.block_shape) == 2, (
                f"block shape must have 2 elements. got {args.block_shape}")

    if args.pf_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
        assert args.world_size == 1, (
            "Single GPU objects need world size set to 1")

    if args.torch_trace_dir_path is not None:
        from pathlib import Path
        assert Path(args.torch_trace_dir_path).is_dir(), (
            f"Please create {args.torch_trace_dir_path}")

def make_config(args: argparse.Namespace) -> Config:

    _validate_args(args)

    quant_config = None
    if args.quant_dtype is not None:
        quant_config = FusedMoEQuantConfig(
            quant_dtype=args.quant_dtype,
            per_act_token_quant=args.per_token_quantized_activations,
            per_out_ch_quant=args.per_channel_quantized_weights,
            block_shape=args.block_shape)

    return Config(
        Ms=args.m,
        K=args.k,
        N=args.n,
        E=args.num_experts,
        topks=args.topk,
        dtype=torch.bfloat16,  # hard-coded
        quant_config=quant_config,
        prepare_finalize_type=args.pf_type,
        fused_experts_type=args.experts_type,
        fused_moe_chunk_size=args.fused_moe_chunk_size,
        world_size=args.world_size,
        torch_trace_dir_path=args.torch_trace_dir_path)
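
Taken together, a minimal driver for the two helpers above might look like the sketch below. This is not part of the commit: it assumes make_config_arg_parser and make_config are in scope (the new file's name is not visible in this excerpt), and a Config is consumed by the test/profiling entry points added elsewhere in this change.

    # Sketch only, not part of this commit. Assumes the two helpers above
    # are in scope; valid --pf-type / --experts-type values are the class
    # names listed in mk_objects.py.
    def _main() -> None:
        parser = make_config_arg_parser(
            description="MoE ModularKernel profiling utility")
        args = parser.parse_args()
        config = make_config(args)  # runs _validate_args, then builds a Config
        print(config)


    if __name__ == "__main__":
        _main()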
