diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index a8ef77bc1..42418a212 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -211,7 +211,7 @@ def _training(self, x, **kwargs): # self attention kwargs = _convert_cu_seqlens_for_qksplited(kwargs) - if gpc.config.data.use_packed_dataset is False: + if gpc.config.data.use_packed_dataset is False or self.training is False: kwargs.pop("max_seqlen_q", None) kwargs.pop("max_seqlen_k", None) context = self.inner_attn(q, k, v, **kwargs) @@ -529,7 +529,7 @@ def _training(self, x, **kwargs): kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - if gpc.config.data.use_packed_dataset is False: + if gpc.config.data.use_packed_dataset is False or self.training is False: kwargs.pop("max_seqlen_q", None) kwargs.pop("max_seqlen_k", None)