Commit 2e8e4ee
[misc] fix fa patch (#473)
1 parent 97e3334
File tree

1 file changed: +1 addition, -1 deletion


verl/models/transformers/flash_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def prepare_fa2_from_position_ids(
     position_ids = position_ids.view(-1)
     cu_seqlens = torch.cat(
         (
-            (position_ids == 0).nonzero().view(-1),
+            (position_ids == 0).nonzero().view(-1).to(torch.int32),
             torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
         )
     )
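The reason for the one-character-class change above: `torch.nonzero` returns `int64` indices, and concatenating them with the `int32` total-length tensor promotes the whole `cu_seqlens` result to `int64`, which FlashAttention's variable-length kernels reject (they expect `int32` cumulative sequence lengths). The sketch below reproduces the patched logic in a standalone function (`build_cu_seqlens` is a hypothetical name for illustration, not part of verl's API):

```python
import torch

def build_cu_seqlens(position_ids: torch.Tensor) -> torch.Tensor:
    """Build cumulative sequence lengths (cu_seqlens) from the flattened
    position_ids of packed sequences, mirroring the patched code above.

    Each packed sequence restarts position_ids at 0, so the indices where
    position_ids == 0 mark sequence starts; appending the total length
    yields boundaries [0, len_0, len_0 + len_1, ...]. The explicit
    .to(torch.int32) cast keeps the concatenated result int32, since
    torch.nonzero returns int64 and torch.cat type-promotes.
    """
    position_ids = position_ids.view(-1)
    cu_seqlens = torch.cat(
        (
            (position_ids == 0).nonzero().view(-1).to(torch.int32),
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )
    return cu_seqlens

# Two packed sequences of lengths 3 and 2:
pos = torch.tensor([0, 1, 2, 0, 1])
cu = build_cu_seqlens(pos)
# cu is tensor([0, 3, 5], dtype=torch.int32)
```

Without the cast, `cu` would come out as `int64` here, which is exactly the dtype mismatch this commit fixes.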
