
Commit 58a890f

amend
1 parent 185bff9 commit 58a890f

File tree

2 files changed: +226 -12 lines changed

sota-implementations/grpo/grpo_utils.py

Lines changed: 10 additions & 12 deletions
@@ -12,7 +12,7 @@
 
 from torchrl._utils import logger as torchrl_logger
 from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
+from torchrl.envs.llm import GSM8KEnv, KLRewardTransform, RetrieveKL
 from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
 from torchrl.envs.llm.transforms.enhanced_reasoning import EnhancedReasoningTransform
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
@@ -524,7 +524,6 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
     max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
     if cfg.env.dataset == "gsm8k":
         # Reward scale is 0.0 to 100
-        reward_threshold = 20
         env = GSM8KEnv(
             repeats=cfg.env.repeats,
             tokenizer=train_tokenizer,
@@ -534,7 +533,6 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
         )
     elif cfg.env.dataset == "ifeval":  # ifeval
         # Reward scale is 0.0 to 2.2
-        reward_threshold = 1.0
         env = IFEvalEnv(
             repeats=cfg.env.repeats,
             tokenizer=train_tokenizer,
@@ -555,15 +553,15 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
             #     undo_done=True,
             #     random_prompt=True,
             # ),
-                EnhancedReasoningTransform(
-                    cond=lambda td: td["reward"] <= 1.0 and td["step_count"] < 3,
-                    strategy="user_guidance",  # User tells assistant to reconsider
-                    reward_threshold=1.0,
-                    max_steps=3,
-                    zero_reward=True,
-                    undo_done=True,
-                    random_prompt=True,
-                )
+            EnhancedReasoningTransform(
+                cond=lambda td: td["reward"] <= 1.0 and td["step_count"] < 3,
+                strategy="user_guidance",  # User tells assistant to reconsider
+                reward_threshold=1.0,
+                max_steps=3,
+                zero_reward=True,
+                undo_done=True,
+                random_prompt=True,
+            )
         )
     env = env.append_transform(
         # RetrieveKL will be lazily initialized in the collector.
Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import re
from typing import Callable, Literal, Optional

from tensordict import lazy_stack, TensorDictBase
from torchrl._utils import logger as torchrl_logger

from torchrl.data.llm.history import History
from torchrl.envs import Transform
from torchrl.envs.common import EnvBase


class EnhancedReasoningTransform(Transform):
    """An enhanced transform that adds intelligent prompts to improve IFEval response quality.

    This transform analyzes the reward and response quality to add targeted prompts that help
    the LLM improve its reasoning and response format.

    Args:
        cond (Callable[[TensorDictBase], bool]): Condition function that determines when to add prompts.
        strategy (Literal["user_guidance", "format_reminder", "quality_hint", "thinking", "step_by_step"]):
            The strategy to use for prompting.
        reward_threshold (float): Reward threshold for triggering the transform.
        max_steps (int): Maximum number of steps allowed.
        zero_reward (bool): Whether to zero out the reward when the prompt is added.
        undo_done (bool): Whether to undo the done flag when the prompt is added.
    """

    # Different prompt strategies for different scenarios
    PROMPT_STRATEGIES = {
        "user_guidance": [
            "I notice your response doesn't follow the required format. Please provide your thinking between <think> and </think> tags, and your final answer between <answer> and </answer> tags.",
            "Your response needs to be structured properly. First think through the problem in <think> tags, then give your answer in <answer> tags.",
            "Please reconsider your response. Remember to use <think> tags for your reasoning and <answer> tags for your final response.",
        ],
        "format_reminder": [
            "Remember to use the correct format: <think>your reasoning</think><answer>your answer</answer>",
            "Please structure your response with <think> and <answer> tags as instructed.",
            "Your response should follow this format: <think>...</think><answer>...</answer>",
        ],
        "quality_hint": [
            "Let me help you improve your response. Think about this more carefully and provide a better answer.",
            "Your response could be better. Take a moment to reconsider and provide a more thoughtful answer.",
            "I think you can do better. Please think through this more carefully.",
        ],
        "thinking": [
            "But wait, let me think about this more carefully...",
            "Actually, let me reconsider this...",
            "Let me think about it step by step...",
            "Wait, I need to double-check my reasoning...",
        ],
        "step_by_step": [
            "Let me break this down step by step and think more carefully...",
            "I should approach this systematically. Let me think through each part...",
            "Let me reconsider this by going through it step by step...",
        ]
    }

    def __init__(
        self,
        cond: Callable[[TensorDictBase], bool],
        strategy: Literal["user_guidance", "format_reminder", "quality_hint", "thinking", "step_by_step"] = "user_guidance",
        reward_threshold: float = 1.0,
        max_steps: int = 3,
        zero_reward: bool = True,
        undo_done: bool = True,
        random_prompt: bool = True,
    ) -> None:
        super().__init__()

        self.cond = cond
        self.strategy = strategy
        self.reward_threshold = reward_threshold
        self.max_steps = max_steps
        self.zero_reward = zero_reward
        self.undo_done = undo_done
        self.random_prompt = random_prompt

    def _get_prompt(self) -> str:
        """Get the appropriate prompt based on the strategy."""
        prompts = self.PROMPT_STRATEGIES[self.strategy]
        if self.random_prompt:
            import random
            return random.choice(prompts)
        return prompts[0]

    def _analyze_response_quality(self, content: str) -> dict:
        """Analyze the quality of the response to determine the best strategy."""
        analysis = {
            "has_think_tags": "<think>" in content and "</think>" in content,
            "has_answer_tags": "<answer>" in content and "</answer>" in content,
            "proper_format": self._check_proper_format(content),
            "malformed_tags": self._check_malformed_tags(content),
            "incomplete_response": len(content.strip()) < 50,
        }

        # Determine the best strategy based on analysis
        if not analysis["has_think_tags"] or not analysis["has_answer_tags"]:
            analysis["recommended_strategy"] = "format_reminder"
        elif analysis["malformed_tags"]:
            analysis["recommended_strategy"] = "format_reminder"
        elif analysis["incomplete_response"]:
            analysis["recommended_strategy"] = "quality_hint"
        else:
            analysis["recommended_strategy"] = "thinking"

        return analysis

    def _check_proper_format(self, content: str) -> bool:
        """Check if the response follows the proper IFEval format."""
        # Check for proper tag structure
        think_pattern = r"<think>.*?</think>"
        answer_pattern = r"<answer>.*?</answer>"

        has_think = bool(re.search(think_pattern, content, re.DOTALL))
        has_answer = bool(re.search(answer_pattern, content, re.DOTALL))

        return has_think and has_answer

    def _check_malformed_tags(self, content: str) -> bool:
        """Check for malformed tags with extra spaces or wrong format."""
        malformed_patterns = [
            r"<\s*think\s*>",  # < think >
            r"<\s*answer\s*>",  # < answer >
            r"</\s*think\s*>",  # </ think >
            r"</\s*answer\s*>",  # </ answer >
        ]

        for pattern in malformed_patterns:
            if re.search(pattern, content):
                return True
        return False

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Process the tensordict and add enhanced prompts based on the condition."""
        # Handle batch dimensions
        if next_tensordict.batch_dims >= 1:
            ntds = []
            for td, next_td in zip(tensordict.unbind(0), next_tensordict.unbind(0)):
                ntds.append(self._step(td, next_td))
            next_tensordict.update(lazy_stack(ntds))
            return next_tensordict

        # Check that base_env is on history mode
        parent = self.parent
        if parent is None:
            raise RuntimeError("EnhancedReasoningTransform must be used with a ChatEnv")
        base_env = parent.base_env
        if base_env.input_mode != "history":
            raise RuntimeError(
                "EnhancedReasoningTransform must be used with a ChatEnv in history mode"
            )

        # Check if we should add the prompt
        if self.cond(next_tensordict):
            torchrl_logger.info(f"Adding enhanced reasoning prompt with strategy: {self.strategy}")

            history: History = next_tensordict["history"].prompt
            last_turn = history[..., -1]

            # Analyze the last response to determine the best strategy
            if self.strategy == "user_guidance":
                # Use user guidance strategy - add as a user message
                prompt = self._get_prompt()
                history = history.append(History(role="user", content=prompt))
                next_tensordict["history"].prompt = history

            elif self.strategy == "thinking":
                # Use thinking strategy - add as assistant message
                prompt = self._get_prompt()
                history = history.append(History(role="assistant", content=prompt))
                next_tensordict["history"].prompt = history

            else:
                # For other strategies, use user guidance as default
                prompt = self._get_prompt()
                history = history.append(History(role="user", content=prompt))
                next_tensordict["history"].prompt = history

            # Undo done flag if requested
            if self.undo_done:
                parent: EnvBase = self.parent
                if parent is not None:
                    done_keys = parent.done_keys
                    for key in done_keys:
                        done = next_tensordict.get(key)
                        if done is not None:
                            next_tensordict.set(key, done.zero_())

            # Zero out reward if requested
            if self.zero_reward:
                parent: EnvBase = self.parent
                if parent is not None:
                    reward_keys = parent.reward_keys
                    for key in reward_keys:
                        reward = next_tensordict.get(key)
                        if reward is not None:
                            next_tensordict.set(key, reward.zero_())
        else:
            torchrl_logger.info("Not adding enhanced reasoning prompt.")

        return next_tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Reset the transform state."""
        return tensordict_reset
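For context, a minimal usage sketch (not part of this commit) showing how the new transform can be attached to an environment, mirroring the make_env changes above. The EnhancedReasoningTransform arguments and the import path are taken from the diff; the tokenizer, model name, and repeats value are illustrative assumptions, as is the claim that GSM8KEnv runs in "history" input mode, which the transform's _step requires.

from transformers import AutoTokenizer

from torchrl.envs.llm import GSM8KEnv
from torchrl.envs.llm.transforms.enhanced_reasoning import EnhancedReasoningTransform

# Illustrative tokenizer; any tokenizer matching the policy model would do.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

env = GSM8KEnv(repeats=4, tokenizer=tokenizer)
env = env.append_transform(
    EnhancedReasoningTransform(
        # Re-prompt only while the answer scores poorly and the step budget remains.
        cond=lambda td: td["reward"] <= 1.0 and td["step_count"] < 3,
        strategy="user_guidance",  # the user asks the assistant to reconsider
        reward_threshold=1.0,
        max_steps=3,
        zero_reward=True,    # wipe the reward for the re-prompted step
        undo_done=True,      # clear done flags so the episode continues
        random_prompt=True,  # pick one of the strategy's prompts at random
    )
)

When the condition fires, the transform appends a guidance message to the conversation history and, with the flags above, zeroes the reward and clears the done flags so the rollout continues for another attempt.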

0 commit comments
