from torchrl._utils import logger as torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.envs.llm import GSM8KEnv, KLRewardTransform, RetrieveKL, AddThinkingPrompt
+from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
@@ -523,7 +523,7 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
    max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
    if cfg.env.dataset == "gsm8k":
        # Reward scale is 0.0 to 100
-        reward_threshold = 20
+        reward_threshold = 20
        env = GSM8KEnv(
            repeats=cfg.env.repeats,
            tokenizer=train_tokenizer,
@@ -533,7 +533,7 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
        )
    elif cfg.env.dataset == "ifeval":  # ifeval
        # Reward scale is 0.0 to 2.2
-        reward_threshold = 1.0
+        reward_threshold = 1.0
        env = IFEvalEnv(
            repeats=cfg.env.repeats,
            tokenizer=train_tokenizer,
@@ -546,7 +546,11 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
    if cfg.env.reasoning:
        env = env.append_transform(
            AddThinkingPrompt(
-                cond=lambda td, reward_threshold=reward_threshold, max_steps=max_steps: td["reward"] <= reward_threshold and td["step_count"] < max_steps,
+                cond=lambda td, reward_threshold=reward_threshold, max_steps=max_steps: td[
+                    "reward"
+                ]
+                <= reward_threshold
+                and td["step_count"] < max_steps,
                role="assistant",
                edit_last_turn=True,
                zero_reward=True,
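For reference, the condition passed to AddThinkingPrompt can be exercised on its own. The sketch below is illustrative only: the keys "reward" and "step_count", the threshold value, and the default-argument binding mirror the hunk above, while the concrete reward/step values and the standalone TensorDict are assumptions made for the example.

# Minimal sketch of the `cond` predicate from the diff, evaluated outside the env.
# The numeric values are hypothetical; only the keys and the binding pattern come
# from the change above.
import torch
from tensordict import TensorDict

reward_threshold = 20  # GSM8K threshold used in the diff
max_steps = 4          # stand-in for cfg.env.max_steps

cond = lambda td, reward_threshold=reward_threshold, max_steps=max_steps: (
    td["reward"] <= reward_threshold and td["step_count"] < max_steps
)

td = TensorDict({"reward": torch.tensor(5.0), "step_count": torch.tensor(1)}, batch_size=[])
print(bool(cond(td)))  # True: reward is below the threshold and the step budget is not exhausted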