Skip to content

Commit a171e32

Browse files
committed
amend
1 parent 3fe6549 commit a171e32

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

sota-implementations/grpo/grpo_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,9 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
548548
AddThinkingPrompt(
549549
cond=lambda td: td["reward"] <= reward_threshold
550550
and td["step_count"] < max_steps,
551-
role="assistant",
552-
edit_last_turn=True,
553-
zero_reward=True,
551+
role="user",
552+
edit_last_turn=False,
553+
zero_reward=False,
554554
undo_done=True,
555555
),
556556
)

torchrl/envs/llm/datasets/ifeval.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@
77
from typing import Any, Callable, Literal
88

99
import torch
10-
from tensordict import (
11-
NonTensorData,
12-
NonTensorStack,
13-
TensorClass,
14-
TensorDict,
15-
)
10+
from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict
1611
from torchrl._utils import logger as torchrl_logger
1712
from torchrl.data import Composite, NonTensor, Unbounded
1813
from torchrl.envs import StepCounter
@@ -72,16 +67,23 @@ def _collate_fn(batch):
7267
batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
7368
batch.rename_key_("prompt", "query")
7469
# we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
75-
instruction_id_list = batch.get("instruction_id_list")
70+
instruction_id_list = batch["instruction_id_list"]
7671
# instruction_id_list should be a list of lists
7772
instruction_id_list = NonTensorStack(
78-
*[NonTensorData([item] if not isinstance(item, list) else item) for item in instruction_id_list]
73+
*[
74+
NonTensorData([item] if not isinstance(item, list) else item)
75+
for item in instruction_id_list
76+
]
77+
)
78+
kwargs = batch["kwargs"]
79+
kwargs = NonTensorStack(
80+
*[
81+
NonTensorData([item] if not isinstance(item, list) else item)
82+
for item in kwargs
83+
]
7984
)
80-
kwargs = batch.get("kwargs")
81-
kwargs = NonTensorStack(*[NonTensorData([item] if not isinstance(item, dict) else item) for item in kwargs])
8285
batch.set("instruction_id_list", instruction_id_list)
8386
batch.set("kwargs", kwargs)
84-
torchrl_logger.info(f"Collated batch: {batch}")
8587
# we don't need a tensorclass here
8688
return batch
8789
# return IFEvalData.from_tensordict(batch)

0 commit comments

Comments
 (0)