Commit 4500268

custom op register

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 89c1a0f

6 files changed (+266, -306 lines)

vllm_ascend/ops/activation.py
Lines changed: 15 additions & 13 deletions

@@ -21,22 +21,24 @@
 from vllm_ascend.utils import is_310p
 
 
-def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-    import torch_npu
+@QuickGELU.register_oot
+class AscendQuickGELU(QuickGELU):
 
-    if is_310p():
-        out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
-    else:
-        out = torch_npu.npu_swiglu(x)
-    return out
+    def forward_oot(self, x: torch.tensor) -> torch.Tensor:
+        import torch_npu
 
+        out = torch_npu.npu_fast_gelu(x)
+        return out
 
-def quick_gelu_forward_oot(self, x: torch.tensor) -> torch.Tensor:
-    import torch_npu
 
-    out = torch_npu.npu_fast_gelu(x)
-    return out
+@SiluAndMul.register_oot
+class AscendSiluAndMul(SiluAndMul):
 
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        import torch_npu
 
-QuickGELU.forward_oot = quick_gelu_forward_oot
-SiluAndMul.forward_oot = silu_and_mul_forward_oot
+        if is_310p():
+            out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
+        else:
+            out = torch_npu.npu_swiglu(x)
+        return out
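
The pattern is the same throughout this commit: instead of monkey-patching forward_oot onto vLLM's ops, each Ascend implementation is declared as a subclass and registered through register_oot. A rough sketch of how such a registration decorator could behave is below; it is illustrative only and does not reflect vLLM's actual CustomOp internals: the registry attribute and the resolve() helper are assumptions.

# Illustrative sketch only, not vLLM's actual CustomOp implementation.
# Assumes a per-class registry the framework consults when building layers.
class CustomOpSketch:
    _oot_impl = None  # out-of-tree override, if one was registered

    @classmethod
    def register_oot(cls, impl_cls):
        # Used as a class decorator: @QuickGELU.register_oot records impl_cls
        # as the platform-specific override and returns it unchanged.
        cls._oot_impl = impl_cls
        return impl_cls

    @classmethod
    def resolve(cls):
        # The framework would instantiate the registered override when present,
        # falling back to the base op otherwise.
        return cls._oot_impl if cls._oot_impl is not None else cls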

vllm_ascend/ops/common_fused_moe.py
Lines changed: 60 additions & 61 deletions

@@ -26,76 +26,75 @@
     select_experts)
 from vllm_ascend.utils import is_310p
 
-original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
+@UnquantizedFusedMoEMethod.register_oot
+class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+    """This UnquantizedFusedMoEMethod is used for qwen3-moe.
+    Customize it mainly to support aclgraph
+    """
 
-def unquantized_fused_moe_init_func(self, *args, **kwargs):
-    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
-    vllm_config = get_current_vllm_config()
-    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-    self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+    def __init__(self, *args, **kwargs):
+        super().__init__(self, *args, **kwargs)
+        vllm_config = get_current_vllm_config()
+        self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
 
+    def forward_oot(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        global_num_experts: Optional[int] = None,
+        expert_map: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = select_experts(
+            global_num_experts=global_num_experts,
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
 
-def forward_oot(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[int] = None,
-    expert_map: Optional[torch.Tensor] = None,
-    apply_router_weight_on_input: bool = False,
-    activation: str = "silu",
-) -> torch.Tensor:
-    topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
-        hidden_states=x,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias,
-    )
+        if topk_ids.shape[1] < top_k or is_310p():
+            assert global_num_experts is not None
+            return fused_experts_moge(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input)
 
-    if topk_ids.shape[1] < top_k or is_310p():
-        assert global_num_experts is not None
-        return fused_experts_moge(
+        # If use aclgraph, we need to set max_num_tokens to make
+        # the input shape of `npu_moe_init_routing` fixed
+        max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+
+        return fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             top_k=top_k,
-            global_num_experts=global_num_experts,
             expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input)
-
-    # If use aclgraph, we need to set max_num_tokens to make
-    # the input shape of `npu_moe_init_routing` fixed
-    max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
-
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        top_k=top_k,
-        expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        max_num_tokens=max_num_tokens)
-
-
-UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
-UnquantizedFusedMoEMethod.forward_oot = forward_oot
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            max_num_tokens=max_num_tokens)
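
Here the new __init__ caches use_aclgraph from the current vLLM config so that forward_oot can pin the routing input shape during graph capture. A minimal sketch of that gating, pulled out as a standalone helper, follows; the helper name and the import path for CompilationLevel are assumptions for illustration, not part of the commit.

# Minimal sketch of the aclgraph gating performed in __init__ above; the helper
# name and the import path are illustrative assumptions, not part of this commit.
from vllm.config import CompilationLevel, VllmConfig


def wants_aclgraph(vllm_config: VllmConfig) -> bool:
    # Graph capture only applies with piecewise compilation enabled and eager
    # mode not forced; only then must the MoE input shape stay fixed.
    return (vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)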

vllm_ascend/ops/fused_moe.py
Lines changed: 2 additions & 2 deletions

@@ -938,7 +938,7 @@ def select_experts(
     return topk_weights, topk_ids
 
 
-class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+class AscendDSUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: FusedMoEConfig = None):

@@ -1185,7 +1185,7 @@ def __init__(
             quant_config=quant_config)
 
         if quant_config is None:
-            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            self.quant_method = AscendDSUnquantizedFusedMoEMethod(moe)
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)

vllm_ascend/ops/layernorm.py
Lines changed: 25 additions & 24 deletions

@@ -23,27 +23,28 @@
 from vllm_ascend.utils import is_310p
 
 
-def forward_oot(
-    self,
-    x: torch.Tensor,
-    residual: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    import torch_npu
-
-    if residual is not None:
-        if is_310p():
-            orig_dtype = residual.dtype
-            x = x + residual.to(x.dtype)
-            residual = x.to(orig_dtype)
-            x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                          self.variance_epsilon)
-        else:
-            x, _, residual = torch_npu.npu_add_rms_norm(
-                x, residual, self.weight, self.variance_epsilon)
-        return x, residual
-
-    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-    return x
-
-
-RMSNorm.forward_oot = forward_oot
+@RMSNorm.register_oot
+class AscendRMSNorm(RMSNorm):
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, self.weight, self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
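
For reference, the fused npu_add_rms_norm path above folds the residual add and the RMS normalization into a single kernel call. A pure-PyTorch sketch of the computation it is assumed to perform, for illustration only:

# Pure-PyTorch sketch of the fused add + RMSNorm that npu_add_rms_norm above is
# assumed to compute; illustrative reference only, not the NPU kernel itself.
import torch


def add_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor,
                     weight: torch.Tensor, eps: float):
    added = x + residual
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (added.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    # Mirrors the (normalized, updated residual) pair that forward_oot returns.
    return normed, added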
