diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py
index 1c32643744..0abae22fde 100644
--- a/vllm_ascend/ops/activation.py
+++ b/vllm_ascend/ops/activation.py
@@ -21,22 +21,24 @@ from vllm_ascend.utils import is_310p
-def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-    import torch_npu
+@QuickGELU.register_oot
+class AscendQuickGELU(QuickGELU):
-    if is_310p():
-        out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
-    else:
-        out = torch_npu.npu_swiglu(x)
-    return out
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        import torch_npu
+        out = torch_npu.npu_fast_gelu(x)
+        return out
-def quick_gelu_forward_oot(self, x: torch.tensor) -> torch.Tensor:
-    import torch_npu
-    out = torch_npu.npu_fast_gelu(x)
-    return out
+@SiluAndMul.register_oot
+class AscendSiluAndMul(SiluAndMul):
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        import torch_npu
-QuickGELU.forward_oot = quick_gelu_forward_oot
-SiluAndMul.forward_oot = silu_and_mul_forward_oot
\ No newline at end of file
+        if is_310p():
+            out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
+        else:
+            out = torch_npu.npu_swiglu(x)
+        return out
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 3aa23a2a6e..453b8a7d9d 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -26,76 +26,75 @@ select_experts)
 from vllm_ascend.utils import is_310p
-original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
+@UnquantizedFusedMoEMethod.register_oot
+class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+    """This UnquantizedFusedMoEMethod is used for qwen3-moe.
+    Customize it mainly to support aclgraph
+    """
-def unquantized_fused_moe_init_func(self, *args, **kwargs):
-    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
-    vllm_config = get_current_vllm_config()
-    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-    self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        vllm_config = get_current_vllm_config()
+        self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+    def forward_oot(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        global_num_experts: Optional[int] = None,
+        expert_map: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = select_experts(
+            global_num_experts=global_num_experts,
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
-def forward_oot(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[int] = None,
-    expert_map: Optional[torch.Tensor] = None,
-    apply_router_weight_on_input: bool = False,
-    activation: str = "silu",
-) -> torch.Tensor:
-    topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
-        hidden_states=x,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias,
-    )
+        if topk_ids.shape[1] < top_k or is_310p():
+            assert global_num_experts is not None
+            return fused_experts_moge(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input)
-    if topk_ids.shape[1] < top_k or is_310p():
-        assert global_num_experts is not None
-        return fused_experts_moge(
+        # If use aclgraph, we need to set max_num_tokens to make
+        # the input shape of `npu_moe_init_routing` fixed
+        max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+
+        return fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             top_k=top_k,
-            global_num_experts=global_num_experts,
             expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input)
-
-    # If use aclgraph, we need to set max_num_tokens to make
-    # the input shape of `npu_moe_init_routing` fixed
-    max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
-
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        top_k=top_k,
-        expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        max_num_tokens=max_num_tokens)
-
-
-UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
-UnquantizedFusedMoEMethod.forward_oot = forward_oot
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            max_num_tokens=max_num_tokens)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index aa189428a3..083903cd7b 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -941,7 +941,7 @@ def select_experts(
     return topk_weights, topk_ids
-class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+class AscendDSUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
     def __init__(self, moe: FusedMoEConfig = None):
@@ -1200,7 +1200,7 @@ def __init__(
             quant_config=quant_config)
         if quant_config is None:
-            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            self.quant_method = AscendDSUnquantizedFusedMoEMethod(moe)
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 7b839fe3d0..587824aa6c 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -23,27 +23,28 @@ from vllm_ascend.utils import is_310p
-def forward_oot(
-    self,
-    x: torch.Tensor,
-    residual: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    import torch_npu
-
-    if residual is not None:
-        if is_310p():
-            orig_dtype = residual.dtype
-            x = x + residual.to(x.dtype)
-            residual = x.to(orig_dtype)
-            x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                          self.variance_epsilon)
-        else:
-            x, _, residual = torch_npu.npu_add_rms_norm(
-                x, residual, self.weight, self.variance_epsilon)
-        return x, residual
-
-    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-    return x
-
-
-RMSNorm.forward_oot = forward_oot
+@RMSNorm.register_oot
+class AscendRMSNorm(RMSNorm):
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, self.weight, self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 3dd91ea63f..7faf20ee8e 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -86,207 +86,165 @@ def rope_forward_oot(
     return query.view(query_shape), key.view(key_shape)
-def native_rope_deepseek_forward(self,
-                                 positions: torch.Tensor,
-                                 query: torch.Tensor,
-                                 key: torch.Tensor,
-                                 offsets: Optional[torch.Tensor] = None,
-                                 max_seq_len: Optional[int] = None):
-    if max_seq_len is not None and max_seq_len > self.max_seq_len:
-        _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
-    if len(key.shape) == 2:
-        key = key[:, None, :]
-    # Note: we implement the non neox_style method with shuffle the last dim and neox style
-    # calculation method which is also more compute friendly to the ascend machine
-    # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
-    neox_style = True
-    if self.is_neox_style is False:
-        b, h_q, d = query.shape
-        query = query.view(b, h_q, d // 2, 2).transpose(3,
-                                                        2).reshape(b, h_q, d)
-        b, h_k, d = key.shape
-        key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
-    q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
-                                  neox_style)
-    return q_pe, k_pe
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Inverse dim formula to find dim based on number of rotations
-def yarn_find_correction_dim(num_rotations,
-                             dim,
-                             base=10000,
-                             max_position_embeddings=2048):
-    # Note: use torch instead of math to solve MTP compilation error.
-    return (dim * torch.log(
-        torch.tensor(max_position_embeddings) /
-        (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
-
-
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
-
-
-# Find dim range bounds based on rotations
-def yarn_find_correction_range(low_rot,
-                               high_rot,
-                               dim,
-                               base=10000,
-                               max_position_embeddings=2048):
-    # Note: use torch instead of math to solve MTP compilation error.
-    low = torch.floor(
-        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
-    high = torch.ceil(
-        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
-    # Note: use torch instead of max/min to solve MTP compilation error.
-    return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
-
-
-def yarn_linear_ramp_mask(min_value, max_value, dim):
-    # Note: The if conditional branch is not used here
-    # to solve MTP compilation error.
-    max_value += (min_value == max_value).float() * 0.001
-    linear_func = (torch.arange(dim, dtype=torch.float32) -
-                   min_value) / (max_value - min_value)
-    ramp_func = torch.clamp(linear_func, 0, 1)
-    return ramp_func
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos[position_ids]
-    sin = sin[position_ids]
-    cos = cos[:, None, None, :]
-    sin = sin[:, None, None, :]
-
-    if len(q.shape) == 3:
-        q = q[:, :, None, :]
-    if len(k.shape) == 2:
-        k = k[:, None, None, :]
-    elif len(k.shape) == 3:
-        k = k[:, :, None, :]
-
-    b, h_q, s, d = q.shape
-    q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
-
-    b, h_k, s, d = k.shape
-    k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
-
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-
-    q_embed = q_embed.view(b, h_q, d)
-    k_embed = k_embed.view(b, h_k, d)
-
-    return q_embed, k_embed
-
-
-def _set_cos_sin_cache(self, seq_len, device, dtype):
-    self.max_seq_len_cached = seq_len
-    dim = self.rotary_dim
-
-    freq_extra = 1.0 / (self.base**(
-        torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
-    freq_inter = 1.0 / (self.scaling_factor * self.base**(
-        torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
-
-    low, high = yarn_find_correction_range(
-        self.beta_fast,
-        self.beta_slow,
-        dim,
-        self.base,
-        self.max_position_embeddings,
-    )
-    inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
-        device=device, dtype=torch.float32)
-    inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
-    self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    t = torch.arange(seq_len * self.scaling_factor,
-                     device=device,
-                     dtype=torch.float32)
-
-    freqs = torch.outer(t, inv_freq)
-    cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
-    sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
-    cos_cached = cos_cached.to(dtype)
-    sin_cached = sin_cached.to(dtype)
-    cache = torch.cat([freqs.cos() * self.mscale,
-                       freqs.sin() * self.mscale],
-                      dim=-1).to(dtype)
-    self.register_buffer("cos_sin_cache", cache, persistent=False)
-    self.register_buffer("cos_cached", cos_cached, persistent=False)
-    self.register_buffer("sin_cached", sin_cached, persistent=False)
-
-
-def deepseek_rope_init_func(
-    self,
-    head_size: int,
-    rotary_dim: int,
-    max_position_embeddings: int,
-    base: int,
-    is_neox_style: bool,
-    scaling_factor: float,
-    dtype: torch.dtype,
-    *,
-    extrapolation_factor: float = 1,
-    attn_factor: float = 1,
-    beta_fast: int = 32,
-    beta_slow: int = 1,
-    mscale: float = 1,
-    mscale_all_dim: float = 0,
-) -> None:
-    self.scaling_factor = scaling_factor
-    self.extrapolation_factor = extrapolation_factor
-    self.attn_factor = attn_factor
-    self.beta_fast = beta_fast
-    self.beta_slow = beta_slow
-    # Get n-d magnitude scaling corrected for interpolation.
-    self.mscale = float(
-        yarn_get_mscale(self.scaling_factor, float(mscale)) /
-        yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
-        attn_factor)
-    super(DeepseekScalingRotaryEmbedding,
-          self).__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style, dtype)
-    self.max_seq_len = max_position_embeddings
-    _set_cos_sin_cache(self,
-                       max_position_embeddings,
-                       dtype=dtype,
-                       device="npu")
-
-
-RotaryEmbedding.forward_oot = rope_forward_oot
-
-# Note: we adopt the native huggingface deepseek rope initialization code from
-# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
-# its more ascend compute friendly
-DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func
-DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
+@RotaryEmbedding.register_oot
+class AscendRotaryEmbedding(RotaryEmbedding):
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+        is_neox_style_override: Optional[bool] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override)
+
+
+@DeepseekScalingRotaryEmbedding.register_oot
+class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+    ) -> None:
+
+        # Note: we adopt the native huggingface deepseek rope initialization code from
+        # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
+        # its more ascend compute friendly
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            self._yarn_get_mscale(self.scaling_factor, float(mscale)) /
+            self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
+            attn_factor)
+        super(DeepseekScalingRotaryEmbedding,
+              self).__init__(head_size, rotary_dim, max_position_embeddings,
+                             base, is_neox_style, dtype)
+        self.max_seq_len = max_position_embeddings
+        self._set_cos_sin_cache(max_position_embeddings,
+                                dtype=dtype,
+                                device="npu")
+
+    def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
+
+    def _yarn_linear_ramp_mask(self, min_value, max_value, dim):
+        # Note: The if conditional branch is not used here
+        # to solve MTP compilation error.
+        max_value += (min_value == max_value).float() * 0.001
+        linear_func = (torch.arange(dim, dtype=torch.float32) -
+                       min_value) / (max_value - min_value)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Inverse dim formula to find dim based on number of rotations
+    def _yarn_find_correction_dim(self,
+                                  num_rotations,
+                                  dim,
+                                  base=10000,
+                                  max_position_embeddings=2048):
+        # Note: use torch instead of math to solve MTP compilation error.
+        return (dim * torch.log(
+            torch.tensor(max_position_embeddings) /
+            (num_rotations * 2 * torch.pi))) / (2 *
+                                                torch.log(torch.tensor(base)))
+
+    # Find dim range bounds based on rotations
+    def _yarn_find_correction_range(self,
+                                    low_rot,
+                                    high_rot,
+                                    dim,
+                                    base=10000,
+                                    max_position_embeddings=2048):
+        # Note: use torch instead of math to solve MTP compilation error.
+        low = torch.floor(
+            self._yarn_find_correction_dim(low_rot, dim, base,
+                                           max_position_embeddings))
+        high = torch.ceil(
+            self._yarn_find_correction_dim(high_rot, dim, base,
+                                           max_position_embeddings))
+        # Note: use torch instead of max/min to solve MTP compilation error.
+        return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.rotary_dim
+
+        freq_extra = 1.0 / (self.base**(
+            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freq_inter = 1.0 / (self.scaling_factor * self.base**(
+            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+
+        low, high = self._yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(
+            low, high, dim // 2).to(device=device, dtype=torch.float32)
+        inv_freq = freq_inter * (1 -
+                                 inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len * self.scaling_factor,
+                         device=device,
+                         dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+        cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
+        sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
+        cos_cached = cos_cached.to(dtype)
+        sin_cached = sin_cached.to(dtype)
+        cache = torch.cat(
+            [freqs.cos() * self.mscale,
+             freqs.sin() * self.mscale], dim=-1).to(dtype)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+        self.register_buffer("cos_cached", cos_cached, persistent=False)
+        self.register_buffer("sin_cached", sin_cached, persistent=False)
+
+    def forward(self,
+                positions: torch.Tensor,
+                query: torch.Tensor,
+                key: torch.Tensor,
+                offsets: Optional[torch.Tensor] = None,
+                max_seq_len: Optional[int] = None):
+        if max_seq_len is not None and max_seq_len > self.max_seq_len:
+            self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
+        if len(key.shape) == 2:
+            key = key[:, None, :]
+        # Note: we implement the non neox_style method with shuffle the last dim and neox style
+        # calculation method which is also more compute friendly to the ascend machine
+        # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
+        neox_style = True
+        if self.is_neox_style is False:
+            b, h_q, d = query.shape
+            query = query.view(b, h_q, d // 2,
+                               2).transpose(3, 2).reshape(b, h_q, d)
+            b, h_k, d = key.shape
+            key = key.view(b, h_k, d // 2, 2).transpose(3,
+                                                        2).reshape(b, h_k, d)
+        q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
+                                      neox_style)
+        return q_pe, k_pe
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 7c7ee58033..08eebd1801 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -37,7 +37,7 @@ from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
-from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.fused_moe import AscendDSUnquantizedFusedMoEMethod
 from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
 from .quantizer import AscendQuantizer
@@ -104,7 +104,7 @@ def get_quant_method(self, layer: torch.nn.Module,
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return AscendUnquantizedFusedMoEMethod()
+                return AscendDSUnquantizedFusedMoEMethod()
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
         return None