Skip to content

Commit 0e3f066

Browse files
vmoensromainjln
andauthored
[Refactor] Refactor 'next_' into nested tensordicts (#649)
* init * [Feature] Nested composite spec (#654) * [Feature] Move `transform.forward` to `transform.step` (#660) * transform step function * amend * amend * amend * amend * amend * fixing key names * fixing key names * [Refactor] Transform next remove (#661) * Refactor "next_" into ("next", ) (#673) * amend * amend * bugfix * init * strict=False * strict=False * minor * amend * [BugFix] Use GitHub for flake8 pre-commit hook (#679) * amend * [BugFix] Update to strict select (#675) * init * strict=False * amend * amend * [Feature] Auto-compute stats for ObservationNorm (#669) * Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens <vincentmoens@gmail.com> * amend * amend * lint * bf * bf * amend Co-authored-by: Romain Julien <romainjulien@fb.com> Co-authored-by: Romain Julien <romainjulien@fb.com>
1 parent 354c198 commit 0e3f066

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+893
-720
lines changed

examples/dreamer/dreamer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,7 @@ def main(cfg: "DictConfig"): # noqa: F821
288288
):
289289
sampled_tensordict_save = (
290290
sampled_tensordict.select(
291-
"next_pixels",
292-
"next_reco_pixels",
293-
"state",
291+
"next" "state",
294292
"belief",
295293
)[:4]
296294
.detach()

examples/dreamer/dreamer_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,13 @@ def make_env_transforms(
9393
if cfg.grayscale:
9494
env.append_transform(GrayScale())
9595
env.append_transform(FlattenObservation())
96-
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["next_pixels"]))
96+
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"]))
9797
if stats is None:
9898
obs_stats = {"loc": 0.0, "scale": 1.0}
9999
else:
100100
obs_stats = stats
101101
obs_stats["standard_normal"] = True
102-
env.append_transform(ObservationNorm(**obs_stats, in_keys=["next_pixels"]))
102+
env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"]))
103103
if norm_rewards:
104104
reward_scaling = 1.0
105105
reward_loc = 0.0
@@ -122,8 +122,8 @@ def make_env_transforms(
122122
)
123123

124124
default_dict = {
125-
"next_state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
126-
"next_belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
125+
"state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
126+
"belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
127127
}
128128
env.append_transform(
129129
TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -309,7 +309,7 @@ def call_record(
309309

310310
true_pixels = recover_pixels(world_model_td["next_pixels"], stats)
311311

312-
reco_pixels = recover_pixels(world_model_td["next_reco_pixels"], stats)
312+
reco_pixels = recover_pixels(world_model_td["next", "reco_pixels"], stats)
313313
with autocast(dtype=torch.float16):
314314
world_model_td = world_model_td.select("state", "belief", "reward")
315315
world_model_td = model_based_env.rollout(
@@ -319,7 +319,7 @@ def call_record(
319319
tensordict=world_model_td[:, 0],
320320
)
321321
imagine_pxls = recover_pixels(
322-
model_based_env.decode_obs(world_model_td)["next_reco_pixels"],
322+
model_based_env.decode_obs(world_model_td)["next", "reco_pixels"],
323323
stats,
324324
)
325325

test/_utils_internal.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch.cuda
1414
from tensordict.tensordict import TensorDictBase
1515
from torchrl._utils import seed_generator
16-
from torchrl.data import CompositeSpec
1716
from torchrl.envs import EnvBase
1817

1918

@@ -62,21 +61,20 @@ def _test_fake_tensordict(env: EnvBase):
6261

6362

6463
def _check_dtype(key, value, obs_spec, input_spec):
65-
if key.startswith("next_"):
66-
return
67-
if isinstance(value, TensorDictBase):
64+
if isinstance(value, TensorDictBase) and key == "next":
6865
for _key, _value in value.items():
69-
if isinstance(obs_spec, CompositeSpec) and "next_" + key in obs_spec.keys():
70-
_check_dtype(_key, _value, obs_spec["next_" + key], input_spec=None)
71-
elif isinstance(input_spec, CompositeSpec) and key in input_spec.keys():
72-
_check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
73-
else:
74-
raise KeyError(f"key '{_key}' is unknown.")
66+
_check_dtype(_key, _value, obs_spec, input_spec=None)
67+
elif isinstance(value, TensorDictBase) and key in obs_spec.keys():
68+
for _key, _value in value.items():
69+
_check_dtype(_key, _value, obs_spec=obs_spec[key], input_spec=None)
70+
elif isinstance(value, TensorDictBase) and key in input_spec.keys():
71+
for _key, _value in value.items():
72+
_check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
7573
else:
76-
if obs_spec is not None and "next_" + key in obs_spec.keys():
74+
if obs_spec is not None and key in obs_spec.keys():
7775
assert (
78-
obs_spec["next_" + key].dtype is value.dtype
79-
), f"{obs_spec['next_' + key].dtype} vs {value.dtype} for {key}"
76+
obs_spec[key].dtype is value.dtype
77+
), f"{obs_spec[key].dtype} vs {value.dtype} for {key}"
8078
elif input_spec is not None and key in input_spec.keys():
8179
assert (
8280
input_spec[key].dtype is value.dtype
@@ -112,3 +110,11 @@ def f_retry(*args, **kwargs):
112110
return f_retry # true decorator
113111

114112
return deco_retry
113+
114+
115+
@pytest.fixture
116+
def dtype_fixture():
117+
dtype = torch.get_default_dtype()
118+
torch.set_default_dtype(torch.double)
119+
yield dtype
120+
torch.set_default_dtype(dtype)

test/mocking_classes.py

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __new__(
121121
action_spec = NdUnboundedContinuousTensorSpec((1,))
122122
if observation_spec is None:
123123
observation_spec = CompositeSpec(
124-
next_observation=NdUnboundedContinuousTensorSpec((1,))
124+
observation=NdUnboundedContinuousTensorSpec((1,))
125125
)
126126
if reward_spec is None:
127127
reward_spec = NdUnboundedContinuousTensorSpec((1,))
@@ -152,19 +152,17 @@ def _step(self, tensordict):
152152
)
153153
done = self.counter >= self.max_val
154154
done = torch.tensor([done], dtype=torch.bool, device=self.device)
155-
return TensorDict(
156-
{"reward": n, "done": done, "next_observation": n.clone()}, []
157-
)
155+
return TensorDict({"reward": n, "done": done, "observation": n.clone()}, [])
158156

159-
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
157+
def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
160158
self.max_val = max(self.counter + 100, self.counter * 2)
161159

162160
n = torch.tensor(
163161
[self.counter], device=self.device, dtype=torch.get_default_dtype()
164162
)
165163
done = self.counter >= self.max_val
166164
done = torch.tensor([done], dtype=torch.bool, device=self.device)
167-
return TensorDict({"done": done, "next_observation": n}, [])
165+
return TensorDict({"done": done, "observation": n}, [])
168166

169167
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
170168
return self.step(tensordict)
@@ -192,7 +190,7 @@ def __new__(
192190
)
193191
if observation_spec is None:
194192
observation_spec = CompositeSpec(
195-
next_observation=NdUnboundedContinuousTensorSpec((1,))
193+
observation=NdUnboundedContinuousTensorSpec((1,))
196194
)
197195
if reward_spec is None:
198196
reward_spec = NdUnboundedContinuousTensorSpec((1,))
@@ -226,7 +224,7 @@ def _step(self, tensordict):
226224
)
227225

228226
return TensorDict(
229-
{"reward": n, "done": done, "next_observation": n},
227+
{"reward": n, "done": done, "observation": n},
230228
tensordict.batch_size,
231229
device=self.device,
232230
)
@@ -247,7 +245,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
247245
done = torch.full(batch_size, done, dtype=torch.bool, device=self.device)
248246

249247
return TensorDict(
250-
{"reward": n, "done": done, "next_observation": n},
248+
{"reward": n, "done": done, "observation": n},
251249
batch_size,
252250
device=self.device,
253251
)
@@ -287,10 +285,8 @@ def __new__(
287285
if observation_spec is None:
288286
cls.out_key = "observation"
289287
observation_spec = CompositeSpec(
290-
next_observation=NdUnboundedContinuousTensorSpec(
291-
shape=torch.Size([size])
292-
),
293-
next_observation_orig=NdUnboundedContinuousTensorSpec(
288+
observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])),
289+
observation_orig=NdUnboundedContinuousTensorSpec(
294290
shape=torch.Size([size])
295291
),
296292
)
@@ -308,7 +304,7 @@ def __new__(
308304
cls._out_key = "observation_orig"
309305
input_spec = CompositeSpec(
310306
**{
311-
cls._out_key: observation_spec["next_observation"],
307+
cls._out_key: observation_spec["observation"],
312308
"action": action_spec,
313309
}
314310
)
@@ -325,15 +321,13 @@ def _get_in_obs(self, obs):
325321
def _get_out_obs(self, obs):
326322
return obs
327323

328-
def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
324+
def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase:
329325
self.counter += 1
330326
state = torch.zeros(self.size) + self.counter
331327
if tensordict is None:
332328
tensordict = TensorDict({}, self.batch_size, device=self.device)
333-
tensordict = tensordict.select().set(
334-
"next_" + self.out_key, self._get_out_obs(state)
335-
)
336-
tensordict = tensordict.set("next_" + self._out_key, self._get_out_obs(state))
329+
tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state))
330+
tensordict = tensordict.set(self._out_key, self._get_out_obs(state))
337331
tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
338332
return tensordict
339333

@@ -351,8 +345,8 @@ def _step(
351345
obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
352346
tensordict = tensordict.select() # empty tensordict
353347

354-
tensordict.set("next_" + self.out_key, self._get_out_obs(obs))
355-
tensordict.set("next_" + self._out_key, self._get_out_obs(obs))
348+
tensordict.set(self.out_key, self._get_out_obs(obs))
349+
tensordict.set(self._out_key, self._get_out_obs(obs))
356350

357351
done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
358352
reward = done.any(-1).unsqueeze(-1)
@@ -379,10 +373,8 @@ def __new__(
379373
if observation_spec is None:
380374
cls.out_key = "observation"
381375
observation_spec = CompositeSpec(
382-
next_observation=NdUnboundedContinuousTensorSpec(
383-
shape=torch.Size([size])
384-
),
385-
next_observation_orig=NdUnboundedContinuousTensorSpec(
376+
observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])),
377+
observation_orig=NdUnboundedContinuousTensorSpec(
386378
shape=torch.Size([size])
387379
),
388380
)
@@ -395,7 +387,7 @@ def __new__(
395387
cls._out_key = "observation_orig"
396388
input_spec = CompositeSpec(
397389
**{
398-
cls._out_key: observation_spec["next_observation"],
390+
cls._out_key: observation_spec["observation"],
399391
"action": action_spec,
400392
}
401393
)
@@ -436,8 +428,8 @@ def _step(
436428
obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
437429
tensordict = tensordict.select() # empty tensordict
438430

439-
tensordict.set("next_" + self.out_key, self._get_out_obs(obs))
440-
tensordict.set("next_" + self._out_key, self._get_out_obs(obs))
431+
tensordict.set(self.out_key, self._get_out_obs(obs))
432+
tensordict.set(self._out_key, self._get_out_obs(obs))
441433

442434
done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
443435
reward = done.any(-1).unsqueeze(-1)
@@ -483,10 +475,8 @@ def __new__(
483475
if observation_spec is None:
484476
cls.out_key = "pixels"
485477
observation_spec = CompositeSpec(
486-
next_pixels=NdUnboundedContinuousTensorSpec(
487-
shape=torch.Size([1, 7, 7])
488-
),
489-
next_pixels_orig=NdUnboundedContinuousTensorSpec(
478+
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])),
479+
pixels_orig=NdUnboundedContinuousTensorSpec(
490480
shape=torch.Size([1, 7, 7])
491481
),
492482
)
@@ -499,7 +489,7 @@ def __new__(
499489
cls._out_key = "pixels_orig"
500490
input_spec = CompositeSpec(
501491
**{
502-
cls._out_key: observation_spec["next_pixels_orig"],
492+
cls._out_key: observation_spec["pixels_orig"],
503493
"action": action_spec,
504494
}
505495
)
@@ -537,10 +527,8 @@ def __new__(
537527
if observation_spec is None:
538528
cls.out_key = "pixels"
539529
observation_spec = CompositeSpec(
540-
next_pixels=NdUnboundedContinuousTensorSpec(
541-
shape=torch.Size([7, 7, 3])
542-
),
543-
next_pixels_orig=NdUnboundedContinuousTensorSpec(
530+
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
531+
pixels_orig=NdUnboundedContinuousTensorSpec(
544532
shape=torch.Size([7, 7, 3])
545533
),
546534
)
@@ -555,7 +543,7 @@ def __new__(
555543
cls._out_key = "pixels_orig"
556544
input_spec = CompositeSpec(
557545
**{
558-
cls._out_key: observation_spec["next_pixels_orig"],
546+
cls._out_key: observation_spec["pixels_orig"],
559547
"action": action_spec,
560548
}
561549
)
@@ -599,10 +587,8 @@ def __new__(
599587
if observation_spec is None:
600588
cls.out_key = "pixels"
601589
observation_spec = CompositeSpec(
602-
next_pixels=NdUnboundedContinuousTensorSpec(
603-
shape=torch.Size(pixel_shape)
604-
),
605-
next_pixels_orig=NdUnboundedContinuousTensorSpec(
590+
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)),
591+
pixels_orig=NdUnboundedContinuousTensorSpec(
606592
shape=torch.Size(pixel_shape)
607593
),
608594
)
@@ -615,7 +601,7 @@ def __new__(
615601
if input_spec is None:
616602
cls._out_key = "pixels_orig"
617603
input_spec = CompositeSpec(
618-
**{cls._out_key: observation_spec["next_pixels"], "action": action_spec}
604+
**{cls._out_key: observation_spec["pixels"], "action": action_spec}
619605
)
620606
return super().__new__(
621607
*args,
@@ -650,10 +636,8 @@ def __new__(
650636
if observation_spec is None:
651637
cls.out_key = "pixels"
652638
observation_spec = CompositeSpec(
653-
next_pixels=NdUnboundedContinuousTensorSpec(
654-
shape=torch.Size([7, 7, 3])
655-
),
656-
next_pixels_orig=NdUnboundedContinuousTensorSpec(
639+
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
640+
pixels_orig=NdUnboundedContinuousTensorSpec(
657641
shape=torch.Size([7, 7, 3])
658642
),
659643
)
@@ -714,7 +698,7 @@ def __init__(
714698
batch_size=batch_size,
715699
)
716700
self.observation_spec = CompositeSpec(
717-
next_hidden_observation=NdUnboundedContinuousTensorSpec((4,))
701+
hidden_observation=NdUnboundedContinuousTensorSpec((4,))
718702
)
719703
self.input_spec = CompositeSpec(
720704
hidden_observation=NdUnboundedContinuousTensorSpec((4,)),
@@ -728,9 +712,6 @@ def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
728712
"hidden_observation": self.input_spec["hidden_observation"].rand(
729713
self.batch_size
730714
),
731-
"next_hidden_observation": self.observation_spec[
732-
"next_hidden_observation"
733-
].rand(self.batch_size),
734715
},
735716
batch_size=self.batch_size,
736717
device=self.device,

0 commit comments

Comments
 (0)