from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)

+import vllm_ascend.envs as ascend_envs
from vllm_ascend.platform import CUSTOM_OP_ENABLED


@@ -75,6 +76,52 @@ def rope_forward_oot(
    return query.view(query_shape), key.view(key_shape)


+def rope_forward_oot_npu_mrope(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    is_neox_style_override: Optional[bool] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    import torch_npu
+    query_shape, key_shape = query.shape, key.shape
+    if self.cos_sin_cache.device != query.device:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+    if self.cos_sin_cache.dtype != query.dtype:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+    neox_style = self.is_neox_style
+    if is_neox_style_override is not None:
+        neox_style = is_neox_style_override
+    # adopt custom kernel path for rotary_embedding
+    if custom_rotary_embedding_enabled(query, neox_style, self.head_size):
+        query, key = torch.ops._C.rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            neox_style,
+        )
+        return query.view(query_shape), key.view(key_shape)
+    if offsets is not None:
+        raise NotImplementedError(
+            "Batched rotary embedding is currently not supported on NPU.")
+    else:
+        # TODO: Remove the contiguous in the future.
+        query = query.contiguous().view(query.shape[0], -1)
+        key = key.contiguous().view(key.shape[0], -1)
+        query, key = torch_npu.npu_mrope(
+            positions,
+            query,
+            key,
+            self.cos_sin_cache,
+            self.head_size,
+            mrope_section=[0, 0, 0],
+            rotary_mode="half" if neox_style else "interleave")
+    return query.view(query_shape), key.view(key_shape)
+
+
def native_rope_deepseek_forward(self,
                                 positions: torch.Tensor,
                                 query: torch.Tensor,
@@ -95,8 +142,8 @@ def native_rope_deepseek_forward(self,
                                                    2).reshape(b, h_q, d)
    b, h_k, d = key.shape
    key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
-    q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
-                                  neox_style)
+    q_pe, k_pe = RotaryEmbedding.forward_oot(self, positions, query, key,
+                                             offsets, neox_style)
    return q_pe, k_pe


@@ -270,7 +317,10 @@ def deepseek_rope_init_func(
                       device="npu")


-RotaryEmbedding.forward_oot = rope_forward_oot
+if not ascend_envs.VLLM_ASCEND_ENABLE_NPU_MROPE:
+    RotaryEmbedding.forward_oot = rope_forward_oot
+else:
+    RotaryEmbedding.forward_oot = rope_forward_oot_npu_mrope

# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
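Because the patching in the last hunk happens when the module is imported, the npu_mrope path is selected at import time by the VLLM_ASCEND_ENABLE_NPU_MROPE flag from vllm_ascend.envs. Below is a minimal usage sketch, assuming the flag is read from the process environment; the model name is only an illustrative placeholder.

import os

# Set the flag before vllm / vllm_ascend are imported, since the
# forward_oot patch is chosen when the rotary embedding module loads.
os.environ["VLLM_ASCEND_ENABLE_NPU_MROPE"] = "1"

from vllm import LLM  # loads the vllm_ascend platform plugin on NPU

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # placeholder model
print(llm.generate("Hello"))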