File tree — 1 file changed: +8 −6 lines changed
lines changed Original file line number Diff line number Diff line change @@ -240,17 +240,19 @@ def set_lora(
240
240
def forward (self , x : torch .Tensor ) -> torch .Tensor :
241
241
added_tokens_mask = torch .where (x > self .base_layer .org_vocab_size - 1 ,
242
242
1 , 0 )
243
- embeddings_indices = torch .narrow (
244
- self .punica_wrapper ._embeddings_indices , 1 , 0 , x .size (0 ))
245
243
246
- indices = embeddings_indices [1 ]
244
+ # NB: Don't use torch.narrow here. torch.narrow triggers some
245
+ # Dynamic Shape specialization in torch.compile
246
+ num_tokens = x .shape [0 ]
247
+ indices_1 = self .punica_wrapper ._embeddings_indices [1 ][:num_tokens ]
248
+ indices_0 = self .punica_wrapper ._embeddings_indices [0 ][:num_tokens ]
249
+
247
250
full_lora_a_embeddings = F .embedding (
248
- x + indices ,
251
+ x + indices_1 ,
249
252
self .lora_a_stacked_2d ,
250
253
)
251
- indices = embeddings_indices [0 ]
252
254
full_output = self .base_layer .forward (x +
253
- (indices * added_tokens_mask ))
255
+ (indices_0 * added_tokens_mask ))
254
256
255
257
full_output_org = full_output
256
258
if full_output .ndim == 3 :
You can't perform that action at this time.
0 commit comments