
Commit 96c3003

Author: Vincent Moens
[BugFix] Fix KL penalty
ghstack-source-id: 475dccb
Pull Request resolved: #2908
1 parent 3a9f244 commit 96c3003


1 file changed, 1 addition and 1 deletion


torchrl/envs/transforms/llm.py

@@ -765,7 +765,7 @@ def _step(
         kl = curr_log_prob - log_prob
         if reward is None:
             reward = 0
-        next_tensordict.set(self.out_keys[0], reward + self.coef * kl)
+        next_tensordict.set(self.out_keys[0], reward - self.coef * kl)
         return next_tensordict
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
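
The sign flip in this hunk can be illustrated with a minimal sketch in plain Python. It is not the TorchRL implementation: the helper names `kl_penalized_reward` and `buggy_reward` are hypothetical, scalars stand in for tensors, and it assumes that `log_prob` in the diff is the reference model's log-probability while `curr_log_prob` comes from the current policy. Under that assumption, a positive `kl` means the policy has drifted toward tokens the reference finds less likely, so a penalty should lower the reward.

```python
def kl_penalized_reward(reward, curr_log_prob, ref_log_prob, coef):
    """Hypothetical helper mirroring the fixed line: reward - coef * kl."""
    if reward is None:
        reward = 0.0
    # Same per-sample estimate as in the diff: current minus reference log-prob.
    kl = curr_log_prob - ref_log_prob
    return reward - coef * kl


def buggy_reward(reward, curr_log_prob, ref_log_prob, coef):
    """Hypothetical helper mirroring the pre-fix line: reward + coef * kl."""
    if reward is None:
        reward = 0.0
    kl = curr_log_prob - ref_log_prob
    return reward + coef * kl


# The policy assigns a higher log-prob than the reference, so kl = 2.0.
fixed = kl_penalized_reward(1.0, curr_log_prob=-1.0, ref_log_prob=-3.0, coef=0.5)
buggy = buggy_reward(1.0, curr_log_prob=-1.0, ref_log_prob=-3.0, coef=0.5)

print(fixed)  # 0.0: drifting from the reference lowers the reward
print(buggy)  # 2.0: the old sign rewarded the drift instead
```

With the old sign, maximizing the reward would push the policy away from the reference rather than keep it close, which is the opposite of what a KL penalty is meant to do.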
