Skip to content

Commit 185bff9

Browse files
committed
amend
1 parent 0b658bf commit 185bff9

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

sota-implementations/grpo/grpo_utils.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
1515
from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
1616
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
17+
from torchrl.envs.llm.transforms.enhanced_reasoning import EnhancedReasoningTransform
1718
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
1819
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
1920
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -545,15 +546,24 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
545546
raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")
546547
if cfg.env.reasoning:
547548
env = env.append_transform(
548-
AddThinkingPrompt(
549-
cond=lambda td: td["reward"] <= reward_threshold
550-
and td["step_count"] < max_steps,
551-
role="assistant",
552-
edit_last_turn=True,
553-
zero_reward=True,
554-
undo_done=True,
555-
random_prompt=True,
556-
),
549+
# AddThinkingPrompt(
550+
# cond=lambda td: td["reward"] <= reward_threshold
551+
# and td["step_count"] < max_steps,
552+
# role="assistant",
553+
# edit_last_turn=True,
554+
# zero_reward=True,
555+
# undo_done=True,
556+
# random_prompt=True,
557+
# ),
558+
EnhancedReasoningTransform(
559+
cond=lambda td: td["reward"] <= 1.0 and td["step_count"] < 3,
560+
strategy="user_guidance", # User tells assistant to reconsider
561+
reward_threshold=1.0,
562+
max_steps=3,
563+
zero_reward=True,
564+
undo_done=True,
565+
random_prompt=True,
566+
)
557567
)
558568
env = env.append_transform(
559569
# RetrieveKL will be lazily initialized in the collector.

torchrl/envs/llm/transforms/reason.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,17 @@ def _replace_answer_with_prompt(self, content: str) -> str:
251251

252252
# Clean up any trailing whitespace
253253
modified_content = modified_content.rstrip()
254-
254+
255255
# Ensure we end with the EOS token if the original content had it
256256
if content.endswith("<|im_end|>"):
257257
modified_content = modified_content.rstrip() + "<|im_end|>"
258-
258+
259259
# Ensure proper spacing around the prompt
260260
if not modified_content.endswith(prompt):
261261
# If the prompt wasn't properly inserted, append it
262262
modified_content = content.rstrip()
263263
if modified_content.endswith("<|im_end|>"):
264-
modified_content = modified_content[:-len("<|im_end|>")].rstrip()
264+
modified_content = modified_content[: -len("<|im_end|>")].rstrip()
265265
modified_content = modified_content + "\n\n" + prompt + "<|im_end|>"
266266

267267
else:

0 commit comments

Comments (0)