
Commit 105e655

Re-add 2stage moe
1 parent a9af7a9 commit 105e655

File tree: 5 files changed, +62 -15 lines changed

vllm/envs.py
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/utils.py

vllm/envs.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -85,6 +85,7 @@
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
+    VLLM_ROCM_USE_AITER_2STAGE_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
@@ -598,6 +599,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
              ("true", "1")),
 
+    # use aiter ck fused moe op if aiter ops are enabled
+    "VLLM_ROCM_USE_AITER_2STAGE_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_2STAGE_MOE", "True").lower() in
+             ("true", "1")),
+
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
```

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -121,11 +121,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                               requires_grad=False)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            is_rocm_aiter_2stage_moe_enabled, is_rocm_aiter_moe_enabled,
+            shuffle_weights)
         if is_rocm_aiter_moe_enabled():
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() else (16,
+                                                                          16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
```
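
The only functional change here is that the weight shuffle now depends on which AITER kernel will run: (32, 32) for the new 2-stage ck kernel, (16, 16) for the existing path. A stand-alone sketch of that selection, with plain booleans standing in for is_rocm_aiter_moe_enabled() / is_rocm_aiter_2stage_moe_enabled(), which need a ROCm build:

```python
# Reading aid only, not vLLM code: picks the shuffle layout the same way the
# diff above does once the AITER MoE path is active.
def pick_shuffle_layout(aiter_moe: bool, two_stage: bool):
    if not aiter_moe:
        return None  # weights are left as loaded, no shuffle happens
    return (32, 32) if two_stage else (16, 16)

assert pick_shuffle_layout(aiter_moe=True, two_stage=True) == (32, 32)
assert pick_shuffle_layout(aiter_moe=True, two_stage=False) == (16, 16)
assert pick_shuffle_layout(aiter_moe=False, two_stage=True) is None
```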

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 31 additions & 3 deletions
```diff
@@ -13,6 +13,12 @@ def is_rocm_aiter_moe_enabled() -> bool:
         and envs.VLLM_ROCM_USE_AITER
 
 
+def is_rocm_aiter_2stage_moe_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE \
+        and envs.VLLM_ROCM_USE_AITER
+
+
 def rocm_aiter_asm_moe_tkw1(hidden_states,
                             w1,
                             w2,
@@ -165,6 +171,17 @@ def rocm_aiter_fused_experts(
     elif use_fp8_w8a8:
         assert not apply_router_weight_on_input, (
             "apply_router_weight_on_input is not supported for fp8_w8a8")
+        if is_rocm_aiter_2stage_moe_enabled():
+            from aiter.fused_moe_bf16_asm import ck_moe_2stages
+            return ck_moe_2stages(a1=hidden_states,
+                                  w1=w1,
+                                  w2=w2,
+                                  topk_weight=topk_weights,
+                                  topk_ids=topk_ids,
+                                  fc1_scale=w1_scale,
+                                  fc2_scale=w2_scale,
+                                  a1_scale=a1_scale,
+                                  a2_scale=a2_scale)
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
@@ -187,7 +204,17 @@
         hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
-
+    if is_rocm_aiter_2stage_moe_enabled():
+        from aiter.fused_moe_bf16_asm import ck_moe_2stages
+        return ck_moe_2stages(a1=hidden_states,
+                              w1=w1,
+                              w2=w2,
+                              topk_weight=topk_weights,
+                              topk_ids=topk_ids,
+                              fc1_scale=w1_scale,
+                              fc2_scale=w2_scale,
+                              a1_scale=a1_scale,
+                              a2_scale=a2_scale)
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
@@ -207,7 +234,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices
 
 
-def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor,
+                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.
@@ -220,7 +248,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
     """
    from aiter.ops.shuffle import shuffle_weight
 
-    return tuple(shuffle_weight(tensor) for tensor in tensors)
+    return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
 def expand_weights(*tensors: torch.Tensor,
```
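
One detail of the shuffle_weights change: because the new layout parameter comes after *tensors, it is keyword-only, which is why every call site in this commit spells layout=... explicitly. A small stand-alone illustration of that Python rule (no aiter install needed; the stub only mimics the parameter shape):

```python
# Stub with the same parameter shape as the updated shuffle_weights:
# *tensors absorbs all positional arguments, so layout must be a keyword.
def shuffle_weights_stub(*tensors, layout):
    return tuple((t, layout) for t in tensors)

shuffle_weights_stub("w13", "w2", layout=(16, 16))  # fine
try:
    shuffle_weights_stub("w13", "w2", (16, 16))  # tuple is absorbed into *tensors
except TypeError as err:
    print(err)  # missing 1 required keyword-only argument: 'layout'
```

The ck_moe_2stages calls themselves require the aiter package on ROCm, so they are only reached when is_rocm_aiter_2stage_moe_enabled() is true.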

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -585,7 +585,8 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
+            expand_weights, is_rocm_aiter_2stage_moe_enabled,
+            is_rocm_aiter_moe_enabled, shuffle_weights)
 
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
@@ -615,7 +616,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if is_rocm_aiter_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data)
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                    layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -673,9 +676,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w13_scales.contiguous(), requires_grad=False)
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales.contiguous(), requires_grad=False)
-
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight, layer.w2_weight)
+            layout = (32,
+                      32) if is_rocm_aiter_2stage_moe_enabled() else (16,
+                                                                      16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                        layer.w2_weight,
+                                                        layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
@@ -759,9 +765,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     expansion_dims=expansion_dims)
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales.contiguous(), requires_grad=False)
-
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight, layer.w2_weight)
+            layout = (32,
+                      32) if is_rocm_aiter_2stage_moe_enabled() else (16,
+                                                                      16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                        layer.w2_weight,
+                                                        layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
```
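
Of the three shuffle_weights call sites touched here, one keeps a fixed layout=(16, 16) regardless of the flag, while the other two switch to (32, 32) when the 2-stage kernel is enabled. A hedged reading aid of that decision (the boolean names are stand-ins, not vLLM API):

```python
# Summarises the layout choices made at the fp8 MoE shuffle sites above.
def fp8_moe_shuffle_layout(fixed_site: bool, two_stage: bool) -> tuple[int, int]:
    if fixed_site:
        return (16, 16)  # this call site ignores the 2-stage flag entirely
    return (32, 32) if two_stage else (16, 16)

assert fp8_moe_shuffle_layout(fixed_site=True, two_stage=True) == (16, 16)
assert fp8_moe_shuffle_layout(fixed_site=False, two_stage=True) == (32, 32)
assert fp8_moe_shuffle_layout(fixed_site=False, two_stage=False) == (16, 16)
```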

vllm/model_executor/layers/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -87,7 +87,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:
-        out = ops.LLMM1(weight, x_view, out, 4)
+        out = ops.LLMM1(weight, x_view, 4)
         return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
```
9393
