From b3608c4b80fd8b121fa3268924fca17799d217b4 Mon Sep 17 00:00:00 2001
From: wimh966
Date: Sat, 13 Jan 2024 13:57:39 +0800
Subject: [PATCH] optimize the generation of attention mask

---
 megatron/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/megatron/utils.py b/megatron/utils.py
index 24e1888f79..800dc6705b 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -178,7 +178,7 @@ def get_ltor_masks_and_position_ids(data,
     attention_mask = None
     if not skip_mask:
         attention_mask = torch.tril(torch.ones(
-            (att_mask_batch, seq_length, seq_length))).view(att_mask_batch, 1, seq_length, seq_length)
+            (att_mask_batch, seq_length, seq_length), device=data.device)).view(att_mask_batch, 1, seq_length, seq_length)
 
     # Loss mask.
     loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
@@ -218,7 +218,6 @@ def get_ltor_masks_and_position_ids(data,
 
     # Convert attention mask to binary:
     if not skip_mask:
         attention_mask = (attention_mask < 0.5)
-        attention_mask = attention_mask.to(data.device)
 
     return attention_mask, loss_mask, position_ids
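
The point of the patch is to allocate the causal mask directly on `data.device` instead of building it on the CPU and moving it to the device afterwards with `.to(data.device)`, which removes one host-side allocation and one host-to-device copy per call. Below is a minimal standalone sketch of that pattern, assuming a PyTorch environment; `build_causal_mask` is a hypothetical helper written for illustration and is not part of megatron/utils.py.

```python
import torch

def build_causal_mask(att_mask_batch, seq_length, device):
    # Build the lower-triangular (causal) mask directly on the target device,
    # mirroring the patched code path: no separate .to(device) copy is needed.
    mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=device))
    mask = mask.view(att_mask_batch, 1, seq_length, seq_length)
    # Convert to binary, matching the later `attention_mask < 0.5` step,
    # where True marks positions that should be masked out.
    return mask < 0.5

# Illustrative usage with a dummy token batch standing in for `data`.
data = torch.randint(0, 100, (2, 16))
attention_mask = build_causal_mask(data.size(0), data.size(1), data.device)
```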