Commit 5ebf667

vllmellm and tjtanaa authored

[FEAT][ROCm] Integrate Fused MoE Kernels from AITER (#14967)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent 781d056 commit 5ebf667

9 files changed, +390 -65 lines changed


tests/kernels/test_moe.py
Lines changed: 19 additions & 6 deletions

@@ -3,7 +3,6 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
-
 import pytest
 import torch
 from torch.nn import Parameter
@@ -216,11 +215,17 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype, padding: bool):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
+                     monkeypatch):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
 
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Instantiate our and huggingface's MoE blocks
     config = MixtralConfig()
     hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@@ -268,10 +273,18 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool):
         torch.bfloat16: 1e-2,
     }
 
-    torch.testing.assert_close(hf_states.flatten(0, 1),
-                               vllm_states,
-                               rtol=mixtral_moe_tol[dtype],
-                               atol=mixtral_moe_tol[dtype])
+    if use_rocm_aiter:
+        # The values of rtol and atol are set based on the tests in ROCM AITER package.  # noqa: E501
+        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174  # noqa: E501
+        torch.testing.assert_close(hf_states.flatten(0, 1),
+                                   vllm_states,
+                                   rtol=0.01,
+                                   atol=100)
+    else:
+        torch.testing.assert_close(hf_states.flatten(0, 1),
+                                   vllm_states,
+                                   rtol=mixtral_moe_tol[dtype],
+                                   atol=mixtral_moe_tol[dtype])
 
 
 @pytest.mark.parametrize("m", [1, 33, 64, 222])

tests/model_executor/test_enabled_custom_ops.py
Lines changed: 36 additions & 0 deletions

@@ -7,6 +7,10 @@
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    dispatch_fused_experts_func, dispatch_topk_func,
+    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
+    vllm_topk_softmax)
 from vllm.model_executor.layers.layernorm import (
     RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
     rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
@@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()
 
 
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    topk_func = dispatch_topk_func()
+
+    if current_platform.is_rocm() and int(use_rocm_aiter):
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_topk_softmax)
+
+        assert topk_func == rocm_aiter_topk_softmax
+    else:
+        assert topk_func == vllm_topk_softmax
+
+
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("inplace", [True, False])
+def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
+                                monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    fused_experts_func = dispatch_fused_experts_func(inplace)
+    if current_platform.is_rocm() and int(use_rocm_aiter):
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts)
+
+        assert fused_experts_func == rocm_aiter_fused_experts
+    elif inplace:
+        assert fused_experts_func == torch_vllm_inplace_fused_experts
+    else:
+        assert fused_experts_func == torch_vllm_outplace_fused_experts
+
+
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])

tests/models/decoder_only/language/test_mistral.py
Lines changed: 9 additions & 32 deletions

@@ -174,15 +174,8 @@
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
+def test_models(hf_runner, vllm_runner, example_prompts, model: str,
+                dtype: str, max_tokens: int, num_logprobs: int) -> None:
     # TODO(sang): Sliding window should be tested separately.
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
@@ -206,14 +199,8 @@ def test_models(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_mistral_format(
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
+def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
+                        max_tokens: int, num_logprobs: int) -> None:
     with vllm_runner(
             model,
             dtype=dtype,
@@ -244,11 +231,8 @@ def test_mistral_format(
 
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-def test_mistral_symbolic_languages(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
+def test_mistral_symbolic_languages(vllm_runner, model: str,
+                                    dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=8192,
@@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("model",
                          MISTRAL_FORMAT_MODELS)  # v1 can't do func calling
-def test_mistral_function_calling(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
+def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
                      tokenizer_mode="mistral",
@@ -301,11 +281,8 @@ def test_mistral_function_calling(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(
-    vllm_runner,
-    model: str,
-    guided_backend: str,
-) -> None:
+def test_mistral_guided_decoding(vllm_runner, model: str,
+                                 guided_backend: str) -> None:
     with vllm_runner(model, dtype='bfloat16',
                      tokenizer_mode="mistral") as vllm_model:

tests/quantization/test_fp8.py
Lines changed: 20 additions & 3 deletions

@@ -23,8 +23,14 @@
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", MODELS)
 @pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
-                            monkeypatch) -> None:
+                            use_rocm_aiter: bool, monkeypatch) -> None:
+
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
 
@@ -47,7 +53,13 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
-def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
+                                     use_rocm_aiter: bool, monkeypatch):
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # vllm_runner.apply_model() relies on V0 internals.
     monkeypatch.setenv("VLLM_USE_V1", "0")
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@@ -86,8 +98,13 @@ def check_model(model):
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
-                         monkeypatch) -> None:
+                         use_rocm_aiter: bool, monkeypatch) -> None:
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # vllm_runner.apply_model() relies on V0 internals.
     monkeypatch.setenv("VLLM_USE_V1", "0")
 

vllm/envs.py
Lines changed: 15 additions & 0 deletions

@@ -73,6 +73,8 @@
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_MOE: bool = True
+    VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -513,6 +515,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),
 
+    # Whether to use aiter moe ops.
+    # By default is enabled.
+    "VLLM_ROCM_USE_AITER_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
+             ("true", "1")),
+
+    # Whether to use aiter block scaled moe kernel.
+    # By default this is disabled.
+    "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
+    lambda:
+    (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
+     ("true", "1")),
+
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in

vllm/model_executor/layers/fused_moe/fused_moe.py
Lines changed: 69 additions & 23 deletions

@@ -17,6 +17,10 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
+                                   rocm_aiter_fused_experts,
+                                   rocm_aiter_topk_softmax)
+
 logger = init_logger(__name__)
 
 
@@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
     return config
 
 
+def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
+                      token_expert_indices: torch.Tensor,
+                      gating_output: torch.Tensor,
+                      renormalize: bool) -> tuple[torch.Tensor, ...]:
+    ops.topk_softmax(
+        topk_weights,
+        topk_indices,
+        token_expert_indices,
+        gating_output,
+    )
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_indices
+
+
+def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
+    if is_rocm_aiter_moe_enabled():
+        return rocm_aiter_topk_softmax
+    return vllm_topk_softmax
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -1059,17 +1085,14 @@ def fused_topk(
                                         dtype=torch.int32,
                                         device=hidden_states.device)
 
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
+    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.
 
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_func = dispatch_topk_func()
+    topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
+                                       token_expert_indicies,
+                                       gating_output_float, renormalize)
 
+    del token_expert_indicies  # Not used. Will be used in the future.
     return topk_weights, topk_ids
 
 
@@ -1259,6 +1282,24 @@ def outplace_fused_experts_fake(
     )
 
 
+def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
+    torch.ops.vllm.inplace_fused_experts(**kwargs)
+    hidden_states = kwargs['hidden_states']
+    return hidden_states
+
+
+def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
+    return torch.ops.vllm.outplace_fused_experts(**kwargs)
+
+
+def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
+    if is_rocm_aiter_moe_enabled():
+        return rocm_aiter_fused_experts
+    if inplace:
+        return torch_vllm_inplace_fused_experts
+    return torch_vllm_outplace_fused_experts
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
@@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None,
                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
-
-    if inplace:
-        torch.ops.vllm.inplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, activation,
-            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
-            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-            block_shape)
-        return hidden_states
-    else:
-        return torch.ops.vllm.outplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, activation,
-            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
-            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-            block_shape)
+    return dispatch_fused_experts_func(inplace)(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        w1_zp=w1_zp,
+        w2_zp=w2_zp,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape)
 
 
 def fused_experts_impl(hidden_states: torch.Tensor,
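The refactor replaces the inline if/else in fused_experts with small dispatchers that choose an implementation once and hand back a callable, so the AITER kernels slot in without changing the public signature; the new tests in tests/model_executor/test_enabled_custom_ops.py assert exactly this selection. A toy, self-contained sketch of the pattern (all names below are illustrative; only the selection order mirrors dispatch_fused_experts_func):

from typing import Callable


def aiter_enabled() -> bool:
    # Stand-in for is_rocm_aiter_moe_enabled(), which is gated by the ROCm
    # platform check and the VLLM_ROCM_USE_AITER* flags.
    return False


def inplace_experts(**kwargs) -> str:
    return f"inplace({kwargs['hidden_states']})"


def outplace_experts(**kwargs) -> str:
    return f"outplace({kwargs['hidden_states']})"


def aiter_experts(**kwargs) -> str:
    return f"aiter({kwargs['hidden_states']})"


def dispatch_experts(inplace: bool) -> Callable[..., str]:
    # AITER takes precedence when enabled; otherwise keep the original
    # inplace/outplace split behind a single call site.
    if aiter_enabled():
        return aiter_experts
    return inplace_experts if inplace else outplace_experts


# Callers pass everything as keyword arguments, as fused_experts now does.
print(dispatch_experts(inplace=True)(hidden_states="x"))  # inplace(x)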
