Skip to content

Commit 9ccae47

Browse files
author
Vincent Moens
authored
[Doc] Fix advantage examples (#1600)
1 parent a43612a commit 9ccae47

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchrl/objectives/value/advantages.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -542,7 +542,7 @@ def forward(
542 542
>>> reward = torch.randn(1, 10, 1)
543 543
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
544 544
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
545 -
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
545 +
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
546 546
547 547
"""
548 548
if tensordict.batch_dims < 1:
@@ -743,7 +743,7 @@ def forward(
743 743
>>> reward = torch.randn(1, 10, 1)
744 744
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
745 745
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
746 -
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
746 +
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
747 747
748 748
"""
749 749
if tensordict.batch_dims < 1:
@@ -955,7 +955,7 @@ def forward(
955 955
>>> reward = torch.randn(1, 10, 1)
956 956
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
957 957
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
958 -
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
958 +
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
959 959
960 960
"""
961 961
if tensordict.batch_dims < 1:
@@ -1198,7 +1198,7 @@ def forward(
1198 1198
>>> reward = torch.randn(1, 10, 1)
1199 1199
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1200 1200
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1201 -
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
1201 +
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
1202 1202
1203 1203
"""
1204 1204
if tensordict.batch_dims < 1:

0 commit comments

Comments
 (0)