|
7 | 7 | from typing import Any, Callable, Literal
|
8 | 8 |
|
9 | 9 | import torch
|
10 |
| -from tensordict import ( |
11 |
| - NonTensorData, |
12 |
| - NonTensorStack, |
13 |
| - TensorClass, |
14 |
| - TensorDict, |
15 |
| -) |
| 10 | +from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict |
16 | 11 | from torchrl._utils import logger as torchrl_logger
|
17 | 12 | from torchrl.data import Composite, NonTensor, Unbounded
|
18 | 13 | from torchrl.envs import StepCounter
|
@@ -72,16 +67,23 @@ def _collate_fn(batch):
|
72 | 67 | batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
|
73 | 68 | batch.rename_key_("prompt", "query")
|
74 | 69 | # we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
|
75 |
| - instruction_id_list = batch.get("instruction_id_list") |
| 70 | + instruction_id_list = batch["instruction_id_list"] |
76 | 71 | # instruction_id_list should be a list of lists
|
77 | 72 | instruction_id_list = NonTensorStack(
|
78 |
| - *[NonTensorData([item] if not isinstance(item, list) else item) for item in instruction_id_list] |
| 73 | + *[ |
| 74 | + NonTensorData([item] if not isinstance(item, list) else item) |
| 75 | + for item in instruction_id_list |
| 76 | + ] |
| 77 | + ) |
| 78 | + kwargs = batch["kwargs"] |
| 79 | + kwargs = NonTensorStack( |
| 80 | + *[ |
| 81 | + NonTensorData([item] if not isinstance(item, list) else item) |
| 82 | + for item in kwargs |
| 83 | + ] |
79 | 84 | )
|
80 |
| - kwargs = batch.get("kwargs") |
81 |
| - kwargs = NonTensorStack(*[NonTensorData([item] if not isinstance(item, dict) else item) for item in kwargs]) |
82 | 85 | batch.set("instruction_id_list", instruction_id_list)
|
83 | 86 | batch.set("kwargs", kwargs)
|
84 |
| - torchrl_logger.info(f"Collated batch: {batch}") |
85 | 87 | # we don't need a tensorclass here
|
86 | 88 | return batch
|
87 | 89 | # return IFEvalData.from_tensordict(batch)
|
|
0 commit comments