Skip to content

Commit 3fe6549

Browse files
committed
amend
1 parent 82b75ce commit 3fe6549

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

torchrl/envs/llm/datasets/ifeval.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from typing import Any, Callable, Literal
88

99
import torch
10-
from tensordict import lazy_stack, TensorClass, TensorDict, NonTensorStack, NonTensorData
10+
from tensordict import (
11+
NonTensorData,
12+
NonTensorStack,
13+
TensorClass,
14+
TensorDict,
15+
)
1116
from torchrl._utils import logger as torchrl_logger
1217
from torchrl.data import Composite, NonTensor, Unbounded
1318
from torchrl.envs import StepCounter
@@ -66,12 +71,14 @@ def default_spec(
6671
def _collate_fn(batch):
6772
batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
6873
batch.rename_key_("prompt", "query")
69-
# we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
74+
# we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
7075
instruction_id_list = batch.get("instruction_id_list")
71-
# instruction_id_list should be a list of lists
72-
instruction_id_list = NonTensorStack(*[NonTensorData(item) for item in instruction_id_list])
76+
# instruction_id_list should be a list of lists
77+
instruction_id_list = NonTensorStack(
78+
*[NonTensorData([item] if not isinstance(item, list) else item) for item in instruction_id_list]
79+
)
7380
kwargs = batch.get("kwargs")
74-
kwargs = NonTensorStack(*[NonTensorData(item) for item in kwargs])
81+
kwargs = NonTensorStack(*[NonTensorData([item] if not isinstance(item, dict) else item) for item in kwargs])
7582
batch.set("instruction_id_list", instruction_id_list)
7683
batch.set("kwargs", kwargs)
7784
torchrl_logger.info(f"Collated batch: {batch}")

0 commit comments

Comments
 (0)