@@ -542,7 +542,7 @@ def forward(
542
542
>>> reward = torch.randn(1, 10, 1)
543
543
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
544
544
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
545
- >>> advantage, value_target = module(obs=obs, reward =reward, done =done, next_obs=next_obs, terminated =terminated)
545
+ >>> advantage, value_target = module(obs=obs, next_reward =reward, next_done =done, next_obs=next_obs, next_terminated =terminated)
546
546
547
547
"""
548
548
if tensordict .batch_dims < 1 :
@@ -743,7 +743,7 @@ def forward(
743
743
>>> reward = torch.randn(1, 10, 1)
744
744
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
745
745
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
746
- >>> advantage, value_target = module(obs=obs, reward =reward, done =done, next_obs=next_obs, terminated =terminated)
746
+ >>> advantage, value_target = module(obs=obs, next_reward =reward, next_done =done, next_obs=next_obs, next_terminated =terminated)
747
747
748
748
"""
749
749
if tensordict .batch_dims < 1 :
@@ -955,7 +955,7 @@ def forward(
955
955
>>> reward = torch.randn(1, 10, 1)
956
956
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
957
957
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
958
- >>> advantage, value_target = module(obs=obs, reward =reward, done =done, next_obs=next_obs, terminated =terminated)
958
+ >>> advantage, value_target = module(obs=obs, next_reward =reward, next_done =done, next_obs=next_obs, next_terminated =terminated)
959
959
960
960
"""
961
961
if tensordict .batch_dims < 1 :
@@ -1198,7 +1198,7 @@ def forward(
1198
1198
>>> reward = torch.randn(1, 10, 1)
1199
1199
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1200
1200
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1201
- >>> advantage, value_target = module(obs=obs, reward =reward, done =done, next_obs=next_obs, terminated =terminated)
1201
+ >>> advantage, value_target = module(obs=obs, next_reward =reward, next_done =done, next_obs=next_obs, next_terminated =terminated)
1202
1202
1203
1203
"""
1204
1204
if tensordict .batch_dims < 1 :
0 commit comments