Skip to content

Commit e2d5dbe

Browse files
authored
[Doc] Tutorial revamp (#926)
1 parent 099ced3 commit e2d5dbe

File tree

15 files changed

+961
-1971
lines changed

15 files changed

+961
-1971
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
- name: Get output time
8686
run: echo "The time was ${{ steps.build.outputs.time }}"
8787
- name: Deploy
88-
if: ${{ github.ref == 'refs/heads/main' }}
88+
if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
8989
uses: JamesIves/github-pages-deploy-action@releases/v4
9090
with:
9191
token: ${{ secrets.GITHUB_TOKEN }}

docs/source/_static/img/cartpole.gif

567 KB
Loading
57.2 KB
Loading

docs/source/_static/img/dqn.png

204 KB
Loading

docs/source/_static/img/dqn_td0.png

175 KB
Loading
163 KB
Loading

docs/source/_static/img/pendulum.gif

122 KB
Loading
31.8 KB
Loading

torchrl/data/replay_buffers/storages.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,9 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
367367
)
368368
elif is_tensorclass(data):
369369
out = (
370-
data.expand(self.max_size, *data.shape)
371-
.clone()
372-
.zero_()
373-
.memmap_(prefix=self.scratch_dir)
370+
data.clone()
371+
.expand(self.max_size, *data.shape)
372+
.memmap_like(prefix=self.scratch_dir)
374373
.to(self.device)
375374
)
376375
for key, tensor in sorted(
@@ -384,10 +383,9 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
384383
# out = TensorDict({}, [self.max_size, *data.shape])
385384
print("The storage is being created: ")
386385
out = (
387-
data.expand(self.max_size, *data.shape)
388-
.to_tensordict()
389-
.zero_()
390-
.memmap_(prefix=self.scratch_dir)
386+
data.clone()
387+
.expand(self.max_size, *data.shape)
388+
.memmap_like(prefix=self.scratch_dir)
391389
.to(self.device)
392390
)
393391
for key, tensor in sorted(

torchrl/envs/common.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -584,23 +584,24 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
584584
f"got {tensordict.batch_size} and {self.batch_size}"
585585
)
586586

587-
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
588-
"""Performs a random step in the environment given the action_spec attribute.
587+
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
588+
"""Performs a random action given the action_spec attribute.
589589
590590
Args:
591-
tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
591+
tensordict (TensorDictBase, optional): tensordict where the resulting action should be written.
592592
593593
Returns:
594-
a tensordict object with the new observation after a random step in the environment. The action will
595-
be stored with the "action" key.
594+
a tensordict object with the "action" entry updated with a random
595+
sample from the action-spec.
596596
597597
"""
598598
shape = torch.Size([])
599599
if tensordict is None:
600600
tensordict = TensorDict(
601601
{}, device=self.device, batch_size=self.batch_size, _run_checks=False
602602
)
603-
elif not self.batch_locked and not self.batch_size:
603+
604+
if not self.batch_locked and not self.batch_size:
604605
shape = tensordict.shape
605606
elif not self.batch_locked and tensordict.shape != self.batch_size:
606607
raise RuntimeError(
@@ -611,6 +612,20 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa
611612
)
612613
action = self.action_spec.rand(shape)
613614
tensordict.set("action", action)
615+
return tensordict
616+
617+
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
618+
"""Performs a random step in the environment given the action_spec attribute.
619+
620+
Args:
621+
tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
622+
623+
Returns:
624+
a tensordict object with the new observation after a random step in the environment. The action will
625+
be stored with the "action" key.
626+
627+
"""
628+
tensordict = self.rand_action(tensordict)
614629
return self.step(tensordict)
615630

616631
@property
@@ -680,7 +695,7 @@ def rollout(
680695
if policy is None:
681696

682697
def policy(td):
683-
self.rand_step(td)
698+
self.rand_action(td)
684699
return td
685700

686701
tensordicts = []
@@ -796,16 +811,18 @@ def to(self, device: DEVICE_TYPING) -> EnvBase:
796811
def fake_tensordict(self) -> TensorDictBase:
797812
"""Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout."""
798813
input_spec = self.input_spec
799-
fake_input = input_spec.zero()
800814
observation_spec = self.observation_spec
801815
fake_obs = observation_spec.zero()
816+
fake_input = input_spec.zero()
817+
# the input and output key may match, but the output prevails
818+
# Hence we generate the input, and override using the output
819+
fake_in_out = fake_input.clone().update(fake_obs)
802820
reward_spec = self.reward_spec
803821
fake_reward = reward_spec.zero()
804822
fake_td = TensorDict(
805823
{
806-
**fake_obs,
824+
**fake_in_out,
807825
"next": fake_obs.clone(),
808-
**fake_input,
809826
"reward": fake_reward,
810827
"done": torch.zeros(
811828
(*self.batch_size, 1), dtype=torch.bool, device=self.device

0 commit comments

Comments
 (0)