@@ -137,38 +137,44 @@ def forward(
137
137
k : Tensor ,
138
138
positions : Tensor ,
139
139
):
140
- def _rope (x : te .Tensor , positions : te .Tensor ):
140
+ def _rope_fused (x : te .Tensor , positions : te .Tensor ):
141
+ _ , _ , _ , d_dim = x .shape
142
+ d_dim_half = d_dim // 2
141
143
dtype = x .dtype
142
144
143
145
def compute (b : tir .Var , s : tir .Var , h : tir .Var , d : tir .Var ):
146
+ d1 = d // d_dim_half
147
+ d2 = d % d_dim_half
148
+
144
149
cos_freq , sin_freq , var_map = self .rope_fn (
145
150
positions [s ], d , self .rotary_dim , self .theta , dtype
146
151
)
147
- cos = cos_freq * x [b , s , h , d ]
148
- sin = sin_freq * tir .if_then_else (
152
+ cos = x [b , s , h , d2 * 2 + d1 ] * cos_freq
153
+
154
+ partner_d = tir .if_then_else (
149
155
d < self .rotary_dim // 2 ,
150
- - x [b , s , h , d + self .rotary_dim // 2 ],
151
- x [b , s , h , d - self .rotary_dim // 2 ],
156
+ d + self .rotary_dim // 2 ,
157
+ d - self .rotary_dim // 2 ,
158
+ )
159
+
160
+ partner_d1 = partner_d // d_dim_half
161
+ partner_d2 = partner_d % d_dim_half
162
+ sin = (
163
+ x [b , s , h , partner_d2 * 2 + partner_d1 ]
164
+ * sin_freq
165
+ * tir .if_then_else (
166
+ d < self .rotary_dim // 2 , tir .const (- 1 , dtype ), tir .const (1 , dtype )
167
+ )
152
168
)
153
169
expr = cos + sin
154
- for var , value in var_map .items ():
155
- expr = tir .Let (var , value , expr )
170
+ for var , val in var_map .items ():
171
+ expr = tir .Let (var , val , expr )
156
172
return expr
157
173
158
174
return te .compute (x .shape , compute , name = "yarn_rope" )
159
175
160
- b , s , h , d = q .shape
161
- q = op .reshape (
162
- op .permute_dims (op .reshape (q , (b , s , h , d // 2 , 2 )), [0 , 1 , 2 , 4 , 3 ]), (b , s , h , d )
163
- )
164
-
165
- b , s , h , d = k .shape
166
- k = op .reshape (
167
- op .permute_dims (op .reshape (k , (b , s , h , d // 2 , 2 )), [0 , 1 , 2 , 4 , 3 ]), (b , s , h , d )
168
- )
169
-
170
- q_embed = op .tensor_expr_op (_rope , "rope" , [q , positions ])
171
- k_embed = op .tensor_expr_op (_rope , "rope" , [k , positions ])
176
+ q_embed = op .tensor_expr_op (_rope_fused , "rope" , [q , positions ])
177
+ k_embed = op .tensor_expr_op (_rope_fused , "rope" , [k , positions ])
172
178
return q_embed , k_embed
173
179
174
180
0 commit comments