
Commit 2505ede

amend
1 parent 58a890f commit 2505ede


5 files changed: +16 -237 lines


sota-implementations/grpo/grpo_utils.py

Lines changed: 8 additions & 17 deletions
@@ -12,9 +12,8 @@
 
 from torchrl._utils import logger as torchrl_logger
 from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.envs.llm import GSM8KEnv, KLRewardTransform, RetrieveKL
+from torchrl.envs.llm import GSM8KEnv, KLRewardTransform, RetrieveKL, AddThinkingPrompt
 from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
-from torchrl.envs.llm.transforms.enhanced_reasoning import EnhancedReasoningTransform
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -524,6 +523,7 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
     max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
     if cfg.env.dataset == "gsm8k":
         # Reward scale is 0.0 to 100
+        reward_threshold = 20
         env = GSM8KEnv(
             repeats=cfg.env.repeats,
             tokenizer=train_tokenizer,
@@ -533,6 +533,7 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
         )
     elif cfg.env.dataset == "ifeval":  # ifeval
         # Reward scale is 0.0 to 2.2
+        reward_threshold = 1.0
         env = IFEvalEnv(
             repeats=cfg.env.repeats,
             tokenizer=train_tokenizer,
@@ -544,24 +545,14 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
         raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")
     if cfg.env.reasoning:
         env = env.append_transform(
-            # AddThinkingPrompt(
-            #     cond=lambda td: td["reward"] <= reward_threshold
-            #     and td["step_count"] < max_steps,
-            #     role="assistant",
-            #     edit_last_turn=True,
-            #     zero_reward=True,
-            #     undo_done=True,
-            #     random_prompt=True,
-            # ),
-            EnhancedReasoningTransform(
-                cond=lambda td: td["reward"] <= 1.0 and td["step_count"] < 3,
-                strategy="user_guidance",  # User tells assistant to reconsider
-                reward_threshold=1.0,
-                max_steps=3,
+            AddThinkingPrompt(
+                cond=lambda td, reward_threshold=reward_threshold, max_steps=max_steps: td["reward"] <= reward_threshold and td["step_count"] < max_steps,
+                role="assistant",
+                edit_last_turn=True,
                 zero_reward=True,
                 undo_done=True,
                 random_prompt=True,
-            )
+            ),
         )
         env = env.append_transform(
             # RetrieveKL will be lazily initialized in the collector.
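Note on the replacement block: it passes reward_threshold and max_steps into the lambda as default arguments rather than relying on closure over make_env's locals, so the condition is frozen with the values the environment was built with. A minimal, TorchRL-free sketch of that binding pattern (the thresholds below are illustrative; only the pattern comes from the hunk above):

# Default arguments are evaluated once, when the lambda is created, so each
# condition keeps the threshold it was built with even if the surrounding
# variables are later rebound (as they are per-dataset above).
conditions = []
for reward_threshold, max_steps in [(20, 4), (1.0, 4)]:  # illustrative values
    conditions.append(
        lambda td, reward_threshold=reward_threshold, max_steps=max_steps: (
            td["reward"] <= reward_threshold and td["step_count"] < max_steps
        )
    )

td = {"reward": 5.0, "step_count": 1}
print(conditions[0](td))  # True: 5.0 <= 20 and 1 < 4
print(conditions[1](td))  # False: 5.0 > 1.0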

test/llm/test_wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -1366,8 +1366,8 @@ def test_kl_computation_transform(
 
         # Create KLComputation transform
         kl_transform = KLComputation(
-            gen_log_probs_key=("gen_log_probs", "full"),
-            ref_log_probs_key=("ref_log_probs", "full"),
+            gen_log_probs_full_key=("gen_log_probs", "full"),
+            ref_log_probs_full_key=("ref_log_probs", "full"),
             kl_key="kl",
             add_to_reward=True,
             coeff=1.0,
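For context, the renamed keywords are the only change to how the test builds the transform. A standalone sketch of the updated call (the import path is assumed from the kl.py module touched below; adjust if KLComputation is re-exported elsewhere):

from torchrl.envs.llm.transforms.kl import KLComputation  # assumed path

kl_transform = KLComputation(
    gen_log_probs_full_key=("gen_log_probs", "full"),  # was gen_log_probs_key
    ref_log_probs_full_key=("ref_log_probs", "full"),  # was ref_log_probs_key
    kl_key="kl",
    add_to_reward=True,
    coeff=1.0,
)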

torchrl/envs/llm/transforms/enhanced_reasoning.py

Lines changed: 0 additions & 216 deletions
This file was deleted.

torchrl/envs/llm/transforms/kl.py

Lines changed: 2 additions & 2 deletions
@@ -859,7 +859,7 @@ def __init__(
         )
         t1 = RetrieveLogProb(
             gen_model,
-            log_probs_key=gen_log_probs_full_key,
+            log_probs_full_key=gen_log_probs_full_key,
             assistant_only=assistant_only,
             tokenizer_kwargs=tokenizer_kwargs,
             detach=detach,
@@ -870,7 +870,7 @@ def __init__(
         )
         t2 = RetrieveLogProb(
             ref_model,
-            log_probs_key=ref_log_probs_full_key,
+            log_probs_full_key=ref_log_probs_full_key,
             assistant_only=assistant_only,
             tokenizer_kwargs=tokenizer_kwargs,
             detach=detach,
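The same rename seen from the caller's side: each RetrieveLogProb writes its per-token log-probabilities under the key given by log_probs_full_key, which is what KLComputation reads under the matching *_full_key arguments in the test above. A sketch under stated assumptions: gen_model and ref_model stand in for already-built LLM wrappers (e.g. the vLLMWrapper / TransformersWrapper imported in grpo_utils.py), the assistant_only/detach values are illustrative, and the import path mirrors this file.

# Assumptions: gen_model / ref_model are pre-built LLM wrappers (placeholders);
# only the keyword rename and the arguments shown in the hunks above come from
# the commit.
from torchrl.envs.llm.transforms.kl import RetrieveLogProb  # assumed path

t1 = RetrieveLogProb(
    gen_model,                                      # generation model wrapper (placeholder)
    log_probs_full_key=("gen_log_probs", "full"),   # renamed from log_probs_key
    assistant_only=True,
    detach=True,
)
t2 = RetrieveLogProb(
    ref_model,                                      # frozen reference model wrapper (placeholder)
    log_probs_full_key=("ref_log_probs", "full"),
    assistant_only=True,
    detach=True,
)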

torchrl/objectives/llm/grpo.py

Lines changed: 4 additions & 0 deletions
@@ -261,6 +261,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
+        print(f"log_weight: {log_weight.shape}")
+        print(f"advantage: {advantage.shape}")
+        print(f"mask: {mask.shape}")
+        print(f"data: {tensordict}")
         gain1 = log_weight.exp() * advantage
 
         log_weight_clip = log_weight.clamp(*self._clip_bounds)
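The four added print calls dump shapes unconditionally on every forward pass. If the diagnostics are meant to outlive debugging, an equivalent behind TorchRL's module logger (the same torchrl_logger imported in grpo_utils.py above) stays silent unless DEBUG logging is enabled. A drop-in sketch for the same spot in forward(), not part of this commit:

from torchrl._utils import logger as torchrl_logger  # module-level import in practice

# Same information as the print() lines above, emitted at DEBUG level so it is
# only shown when the torchrl logger is configured for it.
torchrl_logger.debug(
    "GRPO forward: log_weight=%s advantage=%s mask=%s",
    log_weight.shape,
    advantage.shape,
    mask.shape,
)
torchrl_logger.debug("GRPO forward data: %s", tensordict)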
