Skip to content

Commit 4b1ad2b

Browse files
authored
[Doc, BugFix] Fix tutos errors (#817)
1 parent d4d45ec commit 4b1ad2b

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -378,7 +378,8 @@ def make_model():
378378
frames.append(current_frames)
379379

380380
if data["done"].any():
381-
traj_lengths.append(data["step_count"][data["done"]].float().mean().item())
381+
done = data["done"].squeeze(-1)
382+
traj_lengths.append(data["step_count"][done].float().mean().item())
382383

383384
# check that we have enough data to start training
384385
if sum(frames) > init_random_frames:
@@ -612,7 +613,8 @@ def make_model():
612613
frames.append(current_frames)
613614

614615
if data["done"].any():
615-
traj_lengths.append(data["step_count"][data["done"]].float().mean().item())
616+
done = data["done"].squeeze(-1)
617+
traj_lengths.append(data["step_count"][done].float().mean().item())
616618

617619
if sum(frames) > init_random_frames:
618620
for _ in range(n_optim):

tutorials/sphinx-tutorials/torchrl_demo.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -330,6 +330,7 @@
330330
Compose,
331331
NoopResetEnv,
332332
ObservationNorm,
333+
StepCounter,
333334
ToTensorImage,
334335
TransformedEnv,
335336
)
@@ -358,7 +359,7 @@
358359
lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False),
359360
)
360361
env = TransformedEnv(
361-
base_env, Compose(NoopResetEnv(3), ToTensorImage())
362+
base_env, Compose(StepCounter(), ToTensorImage())
362363
) # applies transforms on batch of envs
363364
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
364365
env.reset()
@@ -587,9 +588,9 @@
587588
for i in range(max_steps):
588589
actor(tensordict)
589590
tensordicts[i] = env.step(tensordict)
590-
tensordict = step_mdp(tensordict) # roughly equivalent to obs = next_obs
591-
if env.is_done:
591+
if tensordict["done"].any():
592592
break
593+
tensordict = step_mdp(tensordict) # roughly equivalent to obs = next_obs
593594

594595
tensordicts_prealloc = tensordicts.clone()
595596
print("total steps:", i)
@@ -607,9 +608,9 @@
607608
for _ in range(max_steps):
608609
actor(tensordict)
609610
tensordicts.append(env.step(tensordict))
610-
tensordict = step_mdp(tensordict) # roughly equivalent to obs = next_obs
611-
if env.is_done:
611+
if tensordict["done"].any():
612612
break
613+
tensordict = step_mdp(tensordict) # roughly equivalent to obs = next_obs
613614
tensordicts_stack = torch.stack(tensordicts, 0)
614615
print("total steps:", i)
615616
print(tensordicts_stack)

0 commit comments

Comments
 (0)