```diff
@@ -25,7 +25,7 @@
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

-from ...protocol import DataProto
+from ...protocol import DataProto, batch_collate
 from ...trainer.core_algos import compute_value_loss
 from ...utils.py_functional import append_to_dict
 from ...utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
@@ -61,17 +61,11 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

-        multi_modal_inputs = defaultdict(list)
         if "multi_modal_inputs" in micro_batch:
-            for input_dict in micro_batch["multi_modal_inputs"]:
-                for key, value in input_dict.items():
-                    multi_modal_inputs[key].append(value)
-
-            for key, value in multi_modal_inputs.items():
-                if len(value) != 0:
-                    multi_modal_inputs[key] = torch.cat(value, dim=0)
-                else:
-                    multi_modal_inputs[key] = None
+            multi_modal_inputs = batch_collate(micro_batch["multi_modal_inputs"])
+            multi_modal_inputs = {key: torch.cat(value, dim=0) for key, value in multi_modal_inputs.items()}
+        else:
+            multi_modal_inputs = {}

         if self.config.padding_free:
             input_ids_rmpad, indices, *_ = unpad_input(
```
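The refactor swaps the inline `defaultdict` collation for a shared `batch_collate` helper imported from `...protocol`. Its implementation is not part of this diff; judging from the code it replaces, it presumably groups a list of per-sample dicts into a dict of lists keyed by field name, leaving the `torch.cat` over the batch dimension to the caller. A minimal sketch under that assumption, with hypothetical field names:

```python
from collections import defaultdict
from typing import Any, Dict, List

import torch


def batch_collate(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Sketch of the helper: group a list of per-sample dicts into a dict of
    lists keyed by field name. The real implementation lives in ...protocol
    and may differ in details."""
    collated = defaultdict(list)
    for feature in features:
        for key, value in feature.items():
            collated[key].append(value)
    return dict(collated)


# Usage mirroring the new code path in _forward_micro_batch: collate the
# per-sample multi-modal tensors, then concatenate each field along dim 0.
micro_batch = {
    "multi_modal_inputs": [
        {"pixel_values": torch.randn(4, 3)},  # hypothetical per-sample inputs
        {"pixel_values": torch.randn(2, 3)},
    ]
}
multi_modal_inputs = batch_collate(micro_batch["multi_modal_inputs"])
multi_modal_inputs = {key: torch.cat(value, dim=0) for key, value in multi_modal_inputs.items()}
print(multi_modal_inputs["pixel_values"].shape)  # torch.Size([6, 3])
```

One behavioural difference visible in the diff: the removed block mapped a key with an empty value list to `None`, while the new comprehension assumes every collected list is non-empty.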