@@ -1942,7 +1942,7 @@ def _prepare_generation_config(
19421942
19431943 return generation_config , model_kwargs
19441944
1945- def _get_initial_cache_position (self , input_ids , model_kwargs ):
1945+ def _get_initial_cache_position (self , seq_length , model_kwargs ):
19461946 """Calculates `cache_position` for the pre-fill stage based on the sequence length and optionally past length"""
19471947 if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
19481948 return model_kwargs
@@ -1951,10 +1951,10 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
19511951 cache_position = mint .ones_like (model_kwargs ["inputs_embeds" ][0 , :, 0 ], dtype = ms .int32 ).cumsum (0 ) - 1
19521952 elif "decoder_inputs_embeds" in model_kwargs and self .config .is_encoder_decoder :
19531953 cache_position = (
1954- mint .ones_like (model_kwargs ["decoder_inputs_embeds" ][0 , :, 0 ], dtype = ms .int32 ).cumsum (0 ) - 1
1954+ mint.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=ms.int32).cumsum(0) - 1
19551955 )
19561956 else :
1957- cache_position = mint .ones_like ( input_ids [ 0 , :] , dtype = ms .int32 ).cumsum (0 ) - 1
1957+ cache_position = mint .ones ( seq_length , dtype = ms .int32 ).cumsum (0 ) - 1
19581958
19591959 if model_kwargs .get ("past_key_values" ) is not None :
19601960 cache = model_kwargs ["past_key_values" ]
@@ -3721,7 +3721,7 @@ def _sample(
37213721 batch_size , cur_len = input_ids .shape
37223722 this_peer_finished = False
37233723 unfinished_sequences = mint .ones (batch_size , dtype = ms .int32 )
3724- model_kwargs = self ._get_initial_cache_position (input_ids , model_kwargs )
3724+ model_kwargs = self ._get_initial_cache_position (cur_len , model_kwargs )
37253725
37263726 model_forward = self .__call__
37273727 compile_forward = self ._valid_auto_compile_criteria (model_kwargs , generation_config )
0 commit comments