[CustomOP][Refactor] Register CustomOP instead of overwrite forward_oot #1647

Draft · wants to merge 1 commit into base: main
28 changes: 15 additions & 13 deletions vllm_ascend/ops/activation.py

@@ -21,22 +21,24 @@
 from vllm_ascend.utils import is_310p


-def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-    import torch_npu
-
-    if is_310p():
-        out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
-    else:
-        out = torch_npu.npu_swiglu(x)
-    return out
-
-
-def quick_gelu_forward_oot(self, x: torch.tensor) -> torch.Tensor:
-    import torch_npu
-
-    out = torch_npu.npu_fast_gelu(x)
-    return out
-
-
-QuickGELU.forward_oot = quick_gelu_forward_oot
-SiluAndMul.forward_oot = silu_and_mul_forward_oot
+@QuickGELU.register_oot
+class AscendQuickGELU(QuickGELU):
+
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        import torch_npu
+
+        out = torch_npu.npu_fast_gelu(x)
+        return out
+
+
+@SiluAndMul.register_oot
+class AscendSiluAndMul(SiluAndMul):
+
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        import torch_npu
+
+        if is_310p():
+            out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
+        else:
+            out = torch_npu.npu_swiglu(x)
+        return out
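
The change in this file is representative of the whole PR: instead of monkey-patching `forward_oot` onto vLLM's op classes, each Ascend override is now declared as a subclass and registered through `register_oot`. A minimal, self-contained sketch of how such a registration decorator can work is below; the registry and dispatch details are assumptions for illustration, not vLLM's actual `CustomOp` implementation.

```python
# Minimal sketch of subclass registration vs. monkey-patching, assuming a
# simplified CustomOp base class (vLLM's real register_oot may differ).
from typing import Dict


class CustomOp:
    # Hypothetical registry of out-of-tree overrides, keyed by the base op.
    _oot_registry: Dict[type, type] = {}

    @classmethod
    def register_oot(cls, oot_cls: type) -> type:
        # Used as a decorator: @QuickGELU.register_oot stores the subclass so
        # the framework can instantiate it in place of the base op.
        CustomOp._oot_registry[cls] = oot_cls
        return oot_cls

    def forward_oot(self, x):
        raise NotImplementedError


class QuickGELU(CustomOp):
    pass


@QuickGELU.register_oot
class AscendQuickGELU(QuickGELU):
    def forward_oot(self, x):
        # The platform-specific override lives on a subclass; the base class is
        # never mutated (the pre-PR code assigned QuickGELU.forward_oot directly).
        return x  # on Ascend this would call torch_npu.npu_fast_gelu(x)


assert CustomOp._oot_registry[QuickGELU] is AscendQuickGELU
```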
121 changes: 60 additions & 61 deletions vllm_ascend/ops/common_fused_moe.py

@@ -26,76 +26,75 @@
     select_experts)
 from vllm_ascend.utils import is_310p

-original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
-
-
-def unquantized_fused_moe_init_func(self, *args, **kwargs):
-    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
-    vllm_config = get_current_vllm_config()
-    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-    self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
-
-
-def forward_oot(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[int] = None,
-    expert_map: Optional[torch.Tensor] = None,
-    apply_router_weight_on_input: bool = False,
-    activation: str = "silu",
-) -> torch.Tensor:
-    topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
-        hidden_states=x,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias,
-    )
-
-    if topk_ids.shape[1] < top_k or is_310p():
-        assert global_num_experts is not None
-        return fused_experts_moge(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            top_k=top_k,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input)
-
-    # If use aclgraph, we need to set max_num_tokens to make
-    # the input shape of `npu_moe_init_routing` fixed
-    max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
-
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        top_k=top_k,
-        expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        max_num_tokens=max_num_tokens)
-
-
-UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
-UnquantizedFusedMoEMethod.forward_oot = forward_oot
+@UnquantizedFusedMoEMethod.register_oot
+class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+    """This UnquantizedFusedMoEMethod is used for qwen3-moe.
+    Customize it mainly to support aclgraph.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        vllm_config = get_current_vllm_config()
+        self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+
+    def forward_oot(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        global_num_experts: Optional[int] = None,
+        expert_map: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = select_experts(
+            global_num_experts=global_num_experts,
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        if topk_ids.shape[1] < top_k or is_310p():
+            assert global_num_experts is not None
+            return fused_experts_moge(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+
+        # If aclgraph is used, set max_num_tokens to keep the input shape
+        # of `npu_moe_init_routing` fixed
+        max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+
+        return fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            max_num_tokens=max_num_tokens)
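
The `use_aclgraph` / `max_num_tokens` handling is kept because graph capture replays kernels with fixed shapes, so the tensor handed to `npu_moe_init_routing` must not change size between steps. The sketch below illustrates the idea with a hypothetical padding helper; it is not part of this PR or of `fused_experts`.

```python
# Illustration only: why a fixed max_num_tokens matters under graph capture.
# pad_for_graph_capture is a hypothetical helper, not vllm-ascend code.
from typing import Optional

import torch


def pad_for_graph_capture(hidden_states: torch.Tensor,
                          max_num_tokens: Optional[int]) -> torch.Tensor:
    """Pad the token dimension so every capture/replay sees the same shape."""
    if max_num_tokens is None:
        # Eager mode: dynamic token counts are fine, no padding needed.
        return hidden_states
    num_pad = max_num_tokens - hidden_states.shape[0]
    # F.pad pairs run from the last dimension backwards; (0, 0, 0, num_pad)
    # leaves the hidden dimension alone and pads the token dimension at the end.
    return torch.nn.functional.pad(hidden_states, (0, 0, 0, num_pad))


x = torch.randn(7, 16)                                      # 7 tokens this step
print(pad_for_graph_capture(x, max_num_tokens=32).shape)    # torch.Size([32, 16])
print(pad_for_graph_capture(x, max_num_tokens=None).shape)  # torch.Size([7, 16])
```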
4 changes: 2 additions & 2 deletions vllm_ascend/ops/fused_moe.py

@@ -941,7 +941,7 @@ def select_experts(
     return topk_weights, topk_ids


-class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+class AscendDSUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

     def __init__(self, moe: FusedMoEConfig = None):

@@ -1200,7 +1200,7 @@ def __init__(
             quant_config=quant_config)

         if quant_config is None:
-            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            self.quant_method = AscendDSUnquantizedFusedMoEMethod(moe)
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
49 changes: 25 additions & 24 deletions vllm_ascend/ops/layernorm.py

@@ -23,27 +23,28 @@
 from vllm_ascend.utils import is_310p


-def forward_oot(
-    self,
-    x: torch.Tensor,
-    residual: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    import torch_npu
-
-    if residual is not None:
-        if is_310p():
-            orig_dtype = residual.dtype
-            x = x + residual.to(x.dtype)
-            residual = x.to(orig_dtype)
-            x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                          self.variance_epsilon)
-        else:
-            x, _, residual = torch_npu.npu_add_rms_norm(
-                x, residual, self.weight, self.variance_epsilon)
-        return x, residual
-
-    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-    return x
-
-
-RMSNorm.forward_oot = forward_oot
+@RMSNorm.register_oot
+class AscendRMSNorm(RMSNorm):
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, self.weight, self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
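
Both branches of `forward_oot` compute the same fused residual-add plus RMSNorm; the 310P path simply performs the add before calling `npu_rms_norm`, while other SoCs use the fused `npu_add_rms_norm` kernel. As a point of reference, a plain-PyTorch version of that computation (an illustration of the math, not the NPU kernel) looks roughly like this:

```python
from typing import Tuple

import torch


def add_rms_norm_reference(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
    # Residual add followed by RMSNorm, written out in plain PyTorch as a
    # reference for what npu_add_rms_norm (or the 310P add + npu_rms_norm
    # fallback) is expected to produce.
    x = x + residual                                   # updated residual stream
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(variance + eps) * weight  # RMSNorm with learned scale
    return normed, x                                   # (output, new residual)
```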