Skip to content

Commit e1620eb

Browse files
authored
[BugFix] Step and maybe reset (#938)
1 parent f47d2cb commit e1620eb

File tree

3 files changed

+153
-6
lines changed

3 files changed

+153
-6
lines changed

test/mocking_classes.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def custom_td(self):
9797

9898

9999
class MockSerialEnv(EnvBase):
100+
"""A simple counting env that is reset after a predifined max number of steps."""
101+
100102
@classmethod
101103
def __new__(
102104
cls,
@@ -844,9 +846,16 @@ def forward(self, observation, action):
844846

845847

846848
class CountingEnv(EnvBase):
847-
def __init__(self, max_steps: int = 5, **kwargs):
849+
"""An env that is done after a given number of steps.
850+
851+
The action is the count increment.
852+
853+
"""
854+
855+
def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
848856
super().__init__(**kwargs)
849857
self.max_steps = max_steps
858+
self.start_val = start_val
850859

851860
self.observation_spec = CompositeSpec(
852861
observation=UnboundedContinuousTensorSpec(
@@ -878,9 +887,9 @@ def _set_seed(self, seed: Optional[int]):
878887
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
879888
if tensordict is not None and "_reset" in tensordict.keys():
880889
_reset = tensordict.get("_reset")
881-
self.count[_reset] = 0
890+
self.count[_reset] = self.start_val
882891
else:
883-
self.count[:] = 0
892+
self.count[:] = self.start_val
884893
return TensorDict(
885894
source={
886895
"observation": self.count.clone(),
@@ -905,3 +914,87 @@ def _step(
905914
batch_size=self.batch_size,
906915
device=self.device,
907916
)
917+
918+
919+
class CountingBatchedEnv(EnvBase):
920+
"""An env that is done after a given number of steps.
921+
922+
The action is the count increment.
923+
924+
Unlike ``CountingEnv``, different envs of the batch can have different max_steps
925+
"""
926+
927+
def __init__(
928+
self,
929+
max_steps: torch.Tensor = None,
930+
start_val: torch.Tensor = None,
931+
**kwargs,
932+
):
933+
super().__init__(**kwargs)
934+
if max_steps is None:
935+
max_steps = torch.tensor(5)
936+
if start_val is None:
937+
start_val = torch.zeros(())
938+
if not max_steps.shape == self.batch_size:
939+
raise RuntimeError("batch_size and max_steps shape must match.")
940+
941+
self.max_steps = max_steps
942+
self.start_val = start_val
943+
944+
self.observation_spec = CompositeSpec(
945+
observation=UnboundedContinuousTensorSpec(
946+
(
947+
*self.batch_size,
948+
1,
949+
)
950+
),
951+
shape=self.batch_size,
952+
)
953+
self.reward_spec = UnboundedContinuousTensorSpec(
954+
(
955+
*self.batch_size,
956+
1,
957+
)
958+
)
959+
self.input_spec = CompositeSpec(
960+
action=BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]),
961+
shape=self.batch_size,
962+
)
963+
964+
self.count = torch.zeros(
965+
(*self.batch_size, 1), device=self.device, dtype=torch.int
966+
)
967+
968+
def _set_seed(self, seed: Optional[int]):
969+
torch.manual_seed(seed)
970+
971+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
972+
if tensordict is not None and "_reset" in tensordict.keys():
973+
_reset = tensordict.get("_reset")
974+
self.count[_reset] = self.start_val[_reset].unsqueeze(-1)
975+
else:
976+
self.count[:] = self.start_val.unsqueeze(-1)
977+
return TensorDict(
978+
source={
979+
"observation": self.count.clone(),
980+
"done": self.count > self.max_steps.unsqueeze(-1),
981+
},
982+
batch_size=self.batch_size,
983+
device=self.device,
984+
)
985+
986+
def _step(
987+
self,
988+
tensordict: TensorDictBase,
989+
) -> TensorDictBase:
990+
action = tensordict.get("action")
991+
self.count += action.to(torch.int).unsqueeze(-1)
992+
return TensorDict(
993+
source={
994+
"observation": self.count,
995+
"done": self.count > self.max_steps.unsqueeze(-1),
996+
"reward": torch.zeros_like(self.count, dtype=torch.float),
997+
},
998+
batch_size=self.batch_size,
999+
device=self.device,
1000+
)

test/test_collector.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED
1313
from mocking_classes import (
1414
ContinuousActionVecMockEnv,
15+
CountingBatchedEnv,
16+
CountingEnv,
1517
DiscreteActionConvMockEnv,
1618
DiscreteActionConvPolicy,
1719
DiscreteActionVecMockEnv,
@@ -1181,6 +1183,52 @@ def test_auto_wrap_error(self, collector_class, env_maker):
11811183
)
11821184

11831185

1186+
@pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv])
1187+
def test_initial_obs_consistency(env_class, seed=1):
1188+
torch.manual_seed(seed)
1189+
start_val = 4
1190+
if env_class == CountingEnv:
1191+
num_envs = 1
1192+
env = CountingEnv(device="cpu", max_steps=8, start_val=start_val)
1193+
max_steps = 8
1194+
elif env_class == CountingBatchedEnv:
1195+
num_envs = 2
1196+
env = CountingBatchedEnv(
1197+
device="cpu",
1198+
batch_size=[num_envs],
1199+
max_steps=torch.arange(num_envs) + 17,
1200+
start_val=torch.ones([num_envs]) * start_val,
1201+
)
1202+
max_steps = env.max_steps.max().item()
1203+
env.set_seed(seed)
1204+
policy = lambda tensordict: tensordict.set(
1205+
"action", torch.ones(tensordict.shape, dtype=torch.int)
1206+
)
1207+
collector = SyncDataCollector(
1208+
create_env_fn=env,
1209+
policy=policy,
1210+
frames_per_batch=((max_steps - 3) * 2 + 2) * num_envs, # at least two episodes
1211+
split_trajs=False,
1212+
)
1213+
for _d in collector:
1214+
break
1215+
obs = _d["observation"].squeeze()
1216+
if env_class == CountingEnv:
1217+
arange_0 = start_val + torch.arange(max_steps - 3)
1218+
arange = start_val + torch.arange(2)
1219+
expected = torch.cat([arange_0, arange_0, arange]).float()
1220+
else:
1221+
# the first env has a shorter horizon than the second
1222+
arange_0 = start_val + torch.arange(max_steps - 3 - 1)
1223+
arange = start_val + torch.arange(start_val)
1224+
expected_0 = torch.cat([arange_0, arange_0, arange]).float()
1225+
arange_0 = start_val + torch.arange(max_steps - 3)
1226+
arange = start_val + torch.arange(2)
1227+
expected_1 = torch.cat([arange_0, arange_0, arange]).float()
1228+
expected = torch.stack([expected_0, expected_1])
1229+
assert torch.allclose(obs, expected)
1230+
1231+
11841232
def weight_reset(m):
11851233
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
11861234
m.reset_parameters()

torchrl/collectors/collectors.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
571571
if self._frames >= self.total_frames:
572572
break
573573

574-
def _reset_if_necessary(self) -> None:
574+
def _step_and_maybe_reset(self) -> None:
575575
done = self._tensordict.get("done")
576576
if not self.reset_when_done:
577577
done = torch.zeros_like(done)
@@ -592,6 +592,11 @@ def _reset_if_necessary(self) -> None:
592592
done_or_terminated = done_or_terminated | _reset
593593

594594
if done_or_terminated.any():
595+
if not done_or_terminated.all():
596+
self._tensordict[~done_or_terminated] = step_mdp(
597+
self._tensordict[~done_or_terminated]
598+
)
599+
595600
traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
596601
steps = steps.clone()
597602
if len(self.env.batch_size):
@@ -617,6 +622,8 @@ def _reset_if_necessary(self) -> None:
617622
("collector", "traj_ids"), traj_ids
618623
) # no ops if they already match
619624
self._tensordict.set_(("collector", "step_count"), steps)
625+
else:
626+
self._tensordict.update(step_mdp(self._tensordict), inplace=True)
620627

621628
@torch.no_grad()
622629
def rollout(self) -> TensorDictBase:
@@ -651,8 +658,7 @@ def rollout(self) -> TensorDictBase:
651658
if is_shared:
652659
self._tensordict_out.share_memory_()
653660

654-
self._reset_if_necessary()
655-
self._tensordict.update(step_mdp(self._tensordict), inplace=True)
661+
self._step_and_maybe_reset()
656662

657663
return self._tensordict_out
658664

0 commit comments

Comments
 (0)