Skip to content

Commit 0b658bf

Browse files
committed
amend
1 parent a6f9111 commit 0b658bf

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

sota-implementations/grpo/grpo_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,10 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
549549
cond=lambda td: td["reward"] <= reward_threshold
550550
and td["step_count"] < max_steps,
551551
role="assistant",
552-
edit_last_turn=False,
553-
zero_reward=False,
552+
edit_last_turn=True,
553+
zero_reward=True,
554554
undo_done=True,
555+
random_prompt=True,
555556
),
556557
)
557558
env = env.append_transform(

torchrl/envs/llm/transforms/reason.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,31 @@ def _replace_answer_with_prompt(self, content: str) -> str:
238238
The modified content with the answer replaced by the thinking prompt
239239
"""
240240
# Pattern to match <answer>...</answer> with optional EOS token
241+
# Use non-greedy matching and be more specific about the end
241242
answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
242243

243244
# Check if there's an answer tag
244245
if "<answer>" in content:
245246
# Replace the answer section with the thinking prompt
246247
prompt = self.prompt
247248

248-
# Replace the answer section
249+
# Replace the answer section, but preserve the EOS token if it exists
249250
modified_content = re.sub(answer_pattern, prompt, content, flags=re.DOTALL)
250251

251252
# Clean up any trailing whitespace
252253
modified_content = modified_content.rstrip()
254+
255+
# Ensure we end with the EOS token if the original content had it
256+
if content.endswith("<|im_end|>"):
257+
modified_content = modified_content.rstrip() + "<|im_end|>"
258+
259+
# Ensure proper spacing around the prompt
260+
if not modified_content.endswith(prompt):
261+
# If the prompt wasn't properly inserted, append it
262+
modified_content = content.rstrip()
263+
if modified_content.endswith("<|im_end|>"):
264+
modified_content = modified_content[:-len("<|im_end|>")].rstrip()
265+
modified_content = modified_content + "\n\n" + prompt + "<|im_end|>"
253266

254267
else:
255268
# No answer tag found, just append the prompt

0 commit comments

Comments
 (0)