|
7 | 7 | from typing import Any, Callable, Literal
|
8 | 8 |
|
9 | 9 | import torch
|
10 |
| -from tensordict import lazy_stack, TensorClass, TensorDict, NonTensorStack, NonTensorData |
| 10 | +from tensordict import ( |
| 11 | + NonTensorData, |
| 12 | + NonTensorStack, |
| 13 | + TensorClass, |
| 14 | + TensorDict, |
| 15 | +) |
11 | 16 | from torchrl._utils import logger as torchrl_logger
|
12 | 17 | from torchrl.data import Composite, NonTensor, Unbounded
|
13 | 18 | from torchrl.envs import StepCounter
|
@@ -66,12 +71,14 @@ def default_spec(
|
66 | 71 | def _collate_fn(batch):
|
67 | 72 | batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
|
68 | 73 | batch.rename_key_("prompt", "query")
|
69 |
| - # we want instruction_id_list and kwargs to be lists, but not NonTensorStacks |
| 74 | + # we want instruction_id_list and kwargs to be lists, but not NonTensorStacks |
70 | 75 | instruction_id_list = batch.get("instruction_id_list")
|
71 |
| - # instruction_id_list should be a list of lists |
72 |
| - instruction_id_list = NonTensorStack(*[NonTensorData(item) for item in instruction_id_list]) |
| 76 | + # instruction_id_list should be a list of lists |
| 77 | + instruction_id_list = NonTensorStack( |
| 78 | + *[NonTensorData([item] if not isinstance(item, list) else item) for item in instruction_id_list] |
| 79 | + ) |
73 | 80 | kwargs = batch.get("kwargs")
|
74 |
| - kwargs = NonTensorStack(*[NonTensorData(item) for item in kwargs]) |
| 81 | + kwargs = NonTensorStack(*[NonTensorData([item] if not isinstance(item, dict) else item) for item in kwargs]) |
75 | 82 | batch.set("instruction_id_list", instruction_id_list)
|
76 | 83 | batch.set("kwargs", kwargs)
|
77 | 84 | torchrl_logger.info(f"Collated batch: {batch}")
|
|