Merged
internlm/model/modules/embedding.py (14 changes: 6 additions & 8 deletions)
@@ -97,13 +97,11 @@ def _update_cos_sin_cache(
         if max_seqlen is not None:
             seqlen = max_seqlen
         elif isinstance(indexes, int):
-            # logic changed temporaryly
-            # seqlen = indexes + x.shape[1] + 1
-            seqlen = gpc.config.data.seq_len
+            seqlen = indexes + x.shape[1]
         else:
             # Note that this statement may cause synchronization between CPU and GPU,
             # so it's best to precompute and pass in max_seqlen ahead of time
-            seqlen = indexes.max().item() + 1
+            seqlen = indexes.max().item()
 
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
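
Note on the arithmetic in the `int` branch: if `indexes` is the 0-based start offset of the current chunk, the chunk covers positions `indexes .. indexes + x.shape[1] - 1`, so a table of length `indexes + x.shape[1]` is exactly enough and the old `+ 1` allocated one spare row. A minimal sketch of the cache rebuild this `seqlen` drives (the function name and the `dim`/`base` defaults are illustrative assumptions, not taken from this file):

```python
import torch

def build_cos_sin_cache(seqlen: int, dim: int = 64, base: float = 10000.0):
    # One inverse frequency per pair of head dimensions (standard RoPE).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Table rows 0 .. seqlen-1; a chunk starting at `indexes` with
    # x.shape[1] tokens reads rows indexes .. indexes + x.shape[1] - 1,
    # so seqlen = indexes + x.shape[1] covers it exactly.
    t = torch.arange(seqlen, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)  # (seqlen, dim // 2)
    return torch.cos(freqs), torch.sin(freqs)

cos, sin = build_cos_sin_cache(seqlen=16 + 4)  # e.g. indexes=16, x.shape[1]=4
print(cos.shape)  # torch.Size([20, 32])
```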
@@ -219,9 +217,9 @@ def __init__(
     def _update_cos_sin_cache(self, x, indexes):
         """x: (batch, seqlen, nheads, headdim)"""
         if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
+            seqlen = indexes.max().item()
         else:
-            seqlen = indexes + x.shape[1] + 1
+            seqlen = indexes + x.shape[1]
 
         t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
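
For reference, the `t = t / self.scaling_factor` line above is the linear position-interpolation variant of RoPE: positions are compressed by `scaling_factor` before the frequency outer product, so a longer sequence maps back into the position range the model was trained on. A self-contained sketch, with the function name and defaults as assumptions:

```python
import torch

def linear_scaled_cos_sin(seqlen: int, scaling_factor: float,
                          dim: int = 64, base: float = 10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Dividing positions by scaling_factor squeezes `seqlen` positions
    # into the original trained position range.
    t = torch.arange(seqlen, dtype=inv_freq.dtype) / scaling_factor
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

# A 4x-longer sequence is interpolated back into the trained range.
cos, sin = linear_scaled_cos_sin(seqlen=8192, scaling_factor=4.0)
```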
@@ -286,9 +284,9 @@ def _update(self, seqlen, x):
     def _update_cos_sin_cache(self, x, indexes):
         """x: (batch, seqlen, nheads, headdim)"""
         if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
+            seqlen = indexes.max().item()
         else:
-            seqlen = indexes + x.shape[1] + 1  # eval_forward
+            seqlen = indexes + x.shape[1]  # eval_forward
         if seqlen <= self.max_position_embeddings:
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
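
The `seqlen <= self.max_position_embeddings` guard above is characteristic of dynamic NTK scaling: within the trained range the stock table is reused, and beyond it the rotary base is enlarged before the frequencies are recomputed. One common formulation is the one used by Hugging Face's dynamic NTK rotary embedding; the sketch below assumes this class follows the same scheme, which the diff alone does not confirm:

```python
import torch

def dynamic_ntk_inv_freq(seqlen: int, max_position_embeddings: int,
                         dim: int = 64, base: float = 10000.0,
                         scaling_factor: float = 1.0):
    # Past the trained range, grow the base with the overshoot so the
    # lowest frequency still spans the longer sequence; otherwise keep
    # the stock base unchanged.
    if seqlen > max_position_embeddings:
        base = base * (
            scaling_factor * seqlen / max_position_embeddings - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Within range: unchanged base; past range: base is raised, flattening
# the frequency spectrum for the extended positions.
short = dynamic_ntk_inv_freq(seqlen=2048, max_position_embeddings=4096)
extended = dynamic_ntk_inv_freq(seqlen=8192, max_position_embeddings=4096)
```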