1 parent 97e3334 commit 2e8e4ee
verl/models/transformers/flash_attention_utils.py
@@ -50,7 +50,7 @@ def prepare_fa2_from_position_ids(
     position_ids = position_ids.view(-1)
     cu_seqlens = torch.cat(
         (
-            (position_ids == 0).nonzero().view(-1),
+            (position_ids == 0).nonzero().view(-1).to(torch.int32),
             torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
         )
     )
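For context, here is a minimal sketch of why the cast matters. `Tensor.nonzero()` returns `int64` indices, so without the explicit `.to(torch.int32)`, `torch.cat` type-promotes the concatenated `cu_seqlens` to `int64`, while flash-attn's varlen kernels (e.g. `flash_attn_varlen_func`) expect `cu_seqlens` as `int32`. The toy `position_ids` below is an illustrative assumption, not data from the commit:

```python
import torch

# Packed batch: position_ids restart at 0 wherever a new sequence begins.
# Toy input for illustration only, not part of the commit.
position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]).view(-1)

# Without .to(torch.int32), nonzero() yields int64 start offsets and
# torch.cat promotes the whole result to int64, which the int32-only
# varlen kernels would then reject.
cu_seqlens = torch.cat(
    (
        (position_ids == 0).nonzero().view(-1).to(torch.int32),  # sequence start offsets
        torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),  # total token count
    )
)
print(cu_seqlens)  # tensor([0, 3, 5, 9], dtype=torch.int32)
```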