Skip to content

Commit b823f3d

Browse files
committed
amend
1 parent 5aa8969 commit b823f3d

File tree

6 files changed

+151
-126
lines changed

6 files changed

+151
-126
lines changed

sota-implementations/grpo/grpo_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from torchrl._utils import logger as torchrl_logger
1414
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
15-
from torchrl.envs.llm import GSM8KEnv, KLRewardTransform, RetrieveKL, AddThinkingPrompt
15+
from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
1616
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
1717
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
1818
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
@@ -523,7 +523,7 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
523523
max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
524524
if cfg.env.dataset == "gsm8k":
525525
# Reward scale is 0.0 to 100
526-
reward_threshold=20
526+
reward_threshold = 20
527527
env = GSM8KEnv(
528528
repeats=cfg.env.repeats,
529529
tokenizer=train_tokenizer,
@@ -533,7 +533,7 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
533533
)
534534
elif cfg.env.dataset == "ifeval": # ifeval
535535
# Reward scale is 0.0 to 2.2
536-
reward_threshold=1.0
536+
reward_threshold = 1.0
537537
env = IFEvalEnv(
538538
repeats=cfg.env.repeats,
539539
tokenizer=train_tokenizer,
@@ -546,7 +546,11 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
546546
if cfg.env.reasoning:
547547
env = env.append_transform(
548548
AddThinkingPrompt(
549-
cond=lambda td, reward_threshol=reward_threshold, max_steps=max_steps: td["reward"] <= reward_threshold and td["step_count"] < max_steps,
549+
cond=lambda td, reward_threshol=reward_threshold, max_steps=max_steps: td[
550+
"reward"
551+
]
552+
<= reward_threshold
553+
and td["step_count"] < max_steps,
550554
role="assistant",
551555
edit_last_turn=True,
552556
zero_reward=True,

0 commit comments

Comments
 (0)