File tree Expand file tree Collapse file tree 2 files changed +22
-12
lines changed
sota-implementations/grpo
torchrl/envs/llm/transforms Expand file tree Collapse file tree 2 files changed +22
-12
lines changed Original file line number Diff line number Diff line change 14
14
from torchrl .collectors .llm .weight_update .vllm import vLLMUpdater
15
15
from torchrl .envs .llm import AddThinkingPrompt , GSM8KEnv , KLRewardTransform , RetrieveKL
16
16
from torchrl .envs .llm .datasets .ifeval import IFEvalEnv
17
+ from torchrl .envs .llm .transforms .enhanced_reasoning import EnhancedReasoningTransform
17
18
from torchrl .modules .llm import TransformersWrapper , vLLMWrapper
18
19
from transformers .models .auto .modeling_auto import AutoModelForCausalLM
19
20
from transformers .tokenization_utils import PreTrainedTokenizer
@@ -545,15 +546,24 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
545
546
raise NotImplementedError (f"Dataset { cfg .env .dataset } not implemented" )
546
547
if cfg .env .reasoning :
547
548
env = env .append_transform (
548
- AddThinkingPrompt (
549
- cond = lambda td : td ["reward" ] <= reward_threshold
550
- and td ["step_count" ] < max_steps ,
551
- role = "assistant" ,
552
- edit_last_turn = True ,
553
- zero_reward = True ,
554
- undo_done = True ,
555
- random_prompt = True ,
556
- ),
549
+ # AddThinkingPrompt(
550
+ # cond=lambda td: td["reward"] <= reward_threshold
551
+ # and td["step_count"] < max_steps,
552
+ # role="assistant",
553
+ # edit_last_turn=True,
554
+ # zero_reward=True,
555
+ # undo_done=True,
556
+ # random_prompt=True,
557
+ # ),
558
+ EnhancedReasoningTransform (
559
+ cond = lambda td : td ["reward" ] <= 1.0 and td ["step_count" ] < 3 ,
560
+ strategy = "user_guidance" , # User tells assistant to reconsider
561
+ reward_threshold = 1.0 ,
562
+ max_steps = 3 ,
563
+ zero_reward = True ,
564
+ undo_done = True ,
565
+ random_prompt = True ,
566
+ )
557
567
)
558
568
env = env .append_transform (
559
569
# RetrieveKL will be lazily initialized in the collector.
Original file line number Diff line number Diff line change @@ -251,17 +251,17 @@ def _replace_answer_with_prompt(self, content: str) -> str:
251
251
252
252
# Clean up any trailing whitespace
253
253
modified_content = modified_content .rstrip ()
254
-
254
+
255
255
# Ensure we end with the EOS token if the original content had it
256
256
if content .endswith ("<|im_end|>" ):
257
257
modified_content = modified_content .rstrip () + "<|im_end|>"
258
-
258
+
259
259
# Ensure proper spacing around the prompt
260
260
if not modified_content .endswith (prompt ):
261
261
# If the prompt wasn't properly inserted, append it
262
262
modified_content = content .rstrip ()
263
263
if modified_content .endswith ("<|im_end|>" ):
264
- modified_content = modified_content [:- len ("<|im_end|>" )].rstrip ()
264
+ modified_content = modified_content [: - len ("<|im_end|>" )].rstrip ()
265
265
modified_content = modified_content + "\n \n " + prompt + "<|im_end|>"
266
266
267
267
else :
You can’t perform that action at this time.
0 commit comments