Commit 5ab4bd3

[data] better mm data collate (#424)
1 parent 75b07d2 commit 5ab4bd3

File tree

2 files changed: 10 additions, 21 deletions

verl/workers/actor/dp_actor.py

Lines changed: 5 additions & 10 deletions
@@ -26,7 +26,7 @@
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from ...protocol import DataProto
+from ...protocol import DataProto, batch_collate
 from ...trainer.core_algos import average_loss, compute_kl, compute_policy_loss
 from ...utils import torch_functional as VF
 from ...utils.py_functional import append_to_dict
@@ -81,15 +81,10 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature
 
         multi_modal_inputs = defaultdict(list)
         if "multi_modal_inputs" in micro_batch:
-            for input_dict in micro_batch["multi_modal_inputs"]:
-                for key, value in input_dict.items():
-                    multi_modal_inputs[key].append(value)
-
-            for key, value in multi_modal_inputs.items():
-                if len(value) != 0:
-                    multi_modal_inputs[key] = torch.cat(value, dim=0)
-                else:
-                    multi_modal_inputs[key] = None
+            multi_modal_inputs = batch_collate(micro_batch["multi_modal_inputs"])
+            multi_modal_inputs = {key: torch.cat(value, dim=0) for key, value in multi_modal_inputs.items()}
+        else:
+            multi_modal_inputs = {}
 
         if self.config.padding_free:
             input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # (total_nnz, 1)
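In both files the hand-rolled collation loop is replaced by batch_collate, imported from verl/protocol.py. That helper is not part of this diff, so the code below is only a minimal sketch of the list-of-dicts to dict-of-lists grouping the call sites appear to rely on; the real implementation may differ.

# Sketch only: assumed shape of the batch_collate helper imported from
# verl/protocol.py; the actual implementation lives outside this diff.
from collections import defaultdict
from typing import Any, Dict, List


def batch_collate(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Group a list of per-sample dicts into one dict of per-key lists."""
    if len(features) == 0:
        return {}

    batch_features = defaultdict(list)
    for feature in features:
        for key, value in feature.items():
            batch_features[key].append(value)

    return batch_features

The callers then concatenate each grouped list along dim 0, so the helper only has to bucket values by key. One behavioral change visible in the diff: the old loop could map a key to None when its value list was empty, while the new path simply concatenates whatever batch_collate returns and uses a plain empty dict for text-only micro-batches.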

verl/workers/critic/dp_critic.py

Lines changed: 5 additions & 11 deletions
@@ -25,7 +25,7 @@
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from ...protocol import DataProto
+from ...protocol import DataProto, batch_collate
 from ...trainer.core_algos import compute_value_loss
 from ...utils.py_functional import append_to_dict
 from ...utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
@@ -61,17 +61,11 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
 
-        multi_modal_inputs = defaultdict(list)
         if "multi_modal_inputs" in micro_batch:
-            for input_dict in micro_batch["multi_modal_inputs"]:
-                for key, value in input_dict.items():
-                    multi_modal_inputs[key].append(value)
-
-            for key, value in multi_modal_inputs.items():
-                if len(value) != 0:
-                    multi_modal_inputs[key] = torch.cat(value, dim=0)
-                else:
-                    multi_modal_inputs[key] = None
+            multi_modal_inputs = batch_collate(micro_batch["multi_modal_inputs"])
+            multi_modal_inputs = {key: torch.cat(value, dim=0) for key, value in multi_modal_inputs.items()}
+        else:
+            multi_modal_inputs = {}
 
         if self.config.padding_free:
             input_ids_rmpad, indices, *_ = unpad_input(
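As a quick sanity check of the new code path, the toy micro-batch below runs the same two lines the diff adds in both workers. The key names and tensor shapes only mimic Qwen2-VL processor outputs and are not taken from this commit; batch_collate here is the sketch shown above.

import torch

# Hypothetical multimodal micro-batch: one dict of processor outputs per sample.
# Key names and shapes are illustrative (Qwen2-VL-like), not from this repo.
micro_batch = {
    "multi_modal_inputs": [
        {"pixel_values": torch.randn(4, 1176), "image_grid_thw": torch.tensor([[1, 2, 2]])},
        {"pixel_values": torch.randn(16, 1176), "image_grid_thw": torch.tensor([[1, 4, 4]])},
    ]
}

# Same two lines as the new code path: bucket by key, then concatenate on dim 0.
collated = batch_collate(micro_batch["multi_modal_inputs"])
multi_modal_inputs = {key: torch.cat(value, dim=0) for key, value in collated.items()}

print(multi_modal_inputs["pixel_values"].shape)    # torch.Size([20, 1176])
print(multi_modal_inputs["image_grid_thw"].shape)  # torch.Size([2, 3])

Dim-0 concatenation (rather than stacking) matters here because each sample can contribute a different number of patch rows to pixel_values.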
