Skip to content

Commit 225f92f

Browse files
authored
Bugfix: Double2Float default behaviour (#242)
1 parent 5ed32a9 commit 225f92f

File tree

3 files changed: 36 additions and 11 deletions.

torchrl/envs/gym_like.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from typing import Optional, Union, Tuple
45

56
import numpy as np
@@ -38,6 +39,11 @@ def __init__(self, keys=None):
3839
self.keys = keys
3940

4041
def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict:
42+
if not isinstance(info_dict, dict) and len(self.keys):
43+
warnings.warn(
44+
f"Found an info_dict of type {type(info_dict)} "
45+
f"but expected type or subtype `dict`."
46+
)
4147
for key in self.keys:
4248
if key in info_dict:
4349
tensordict[key] = info_dict[key]
@@ -67,6 +73,11 @@ class GymLikeEnv(_EnvWrapper):
6773
It is also expected that env.reset() returns an observation similar to the one observed after a step is completed.
6874
"""
6975

76+
@classmethod
77+
def __new__(cls, *args, **kwargs):
78+
cls._info_dict_reader = None
79+
return super().__new__(cls, *args, **kwargs)
80+
7081
def _step(self, tensordict: _TensorDict) -> _TensorDict:
7182
action = tensordict.get("action")
7283
action_np = self.action_spec.to_numpy(action, safe=False)
@@ -98,7 +109,8 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
98109
)
99110
tensordict_out.set("reward", reward)
100111
tensordict_out.set("done", done)
101-
self.info_dict_reader(info, tensordict_out)
112+
if self.info_dict_reader is not None:
113+
self.info_dict_reader(*info, tensordict_out)
102114

103115
return tensordict_out
104116

@@ -156,17 +168,15 @@ def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
156168
self.info_dict_reader = info_dict_reader
157169
return self
158170

171+
def __repr__(self) -> str:
172+
return (
173+
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
174+
)
175+
159176
@property
160177
def info_dict_reader(self):
161-
if "_info_dict_reader" not in self.__dir__():
162-
self._info_dict_reader = default_info_dict_reader()
163178
return self._info_dict_reader
164179

165180
@info_dict_reader.setter
166181
def info_dict_reader(self, value: callable):
167182
self._info_dict_reader = value
168-
169-
def __repr__(self) -> str:
170-
return (
171-
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
172-
)

torchrl/envs/libs/gym.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
UnboundedContinuousTensorSpec,
2020
)
2121
from ...data.utils import numpy_to_torch_dtype_dict
22-
from ..gym_like import GymLikeEnv
22+
from ..gym_like import GymLikeEnv, default_info_dict_reader
2323
from ..utils import classproperty
2424

2525
try:
@@ -226,6 +226,16 @@ def rebuild_with_kwargs(self, **new_kwargs):
226226
self._env = self._build_env(**self._constructor_kwargs)
227227
self._make_specs(self._env)
228228

229+
@property
230+
def info_dict_reader(self):
231+
if self._info_dict_reader is None:
232+
self._info_dict_reader = default_info_dict_reader()
233+
return self._info_dict_reader
234+
235+
@info_dict_reader.setter
236+
def info_dict_reader(self, value: callable):
237+
self._info_dict_reader = value
238+
229239

230240
class GymEnv(GymWrapper):
231241
"""

torchrl/envs/transforms/transforms.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1245,8 +1245,6 @@ def __init__(
12451245
keys_in: Optional[Sequence[str]] = None,
12461246
keys_inv_in: Optional[Sequence[str]] = None,
12471247
):
1248-
if keys_inv_in is None:
1249-
keys_inv_in = ["action"]
12501248
super().__init__(keys_in=keys_in, keys_inv_in=keys_inv_in)
12511249

12521250
def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
@@ -1286,6 +1284,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
12861284
self._transform_spec(observation_spec)
12871285
return observation_spec
12881286

1287+
def __repr__(self) -> str:
1288+
s = (
1289+
f"{self.__class__.__name__}(keys_in={self.keys_in}, keys_out={self.keys_out},"
1290+
f"keys_inv_in={self.keys_inv_in}, keys_inv_out={self.keys_inv_out})"
1291+
)
1292+
return s
1293+
12891294

12901295
class CatTensors(Transform):
12911296
"""

0 commit comments

Comments
 (0)