Skip to content

Commit c24184d

Browse files
committed
fix
1 parent 25c0e6b commit c24184d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mindone/transformers/generation/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1942,7 +1942,7 @@ def _prepare_generation_config(
19421942

19431943
return generation_config, model_kwargs
19441944

1945-
def _get_initial_cache_position(self, input_ids, model_kwargs):
1945+
def _get_initial_cache_position(self, seq_length, model_kwargs):
19461946
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
19471947
if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
19481948
return model_kwargs
@@ -1951,10 +1951,10 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
19511951
cache_position = mint.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=ms.int32).cumsum(0) - 1
19521952
elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
19531953
cache_position = (
1954-
mint.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=ms.int32).cumsum(0) - 1
1954+
mint.ones(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=ms.int32).cumsum(0) - 1
19551955
)
19561956
else:
1957-
cache_position = mint.ones_like(input_ids[0, :], dtype=ms.int32).cumsum(0) - 1
1957+
cache_position = mint.ones(seq_length, dtype=ms.int32).cumsum(0) - 1
19581958

19591959
if model_kwargs.get("past_key_values") is not None:
19601960
cache = model_kwargs["past_key_values"]
@@ -3721,7 +3721,7 @@ def _sample(
37213721
batch_size, cur_len = input_ids.shape
37223722
this_peer_finished = False
37233723
unfinished_sequences = mint.ones(batch_size, dtype=ms.int32)
3724-
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
3724+
model_kwargs = self._get_initial_cache_position(cur_len, model_kwargs)
37253725

37263726
model_forward = self.__call__
37273727
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)

0 commit comments

Comments
 (0)