|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import warnings |
3 | 4 | from typing import Optional, Union, Tuple
|
4 | 5 |
|
5 | 6 | import numpy as np
|
@@ -38,6 +39,11 @@ def __init__(self, keys=None):
|
38 | 39 | self.keys = keys
|
39 | 40 |
|
40 | 41 | def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict:
|
| 42 | + if not isinstance(info_dict, dict) and len(self.keys): |
| 43 | + warnings.warn( |
| 44 | + f"Found an info_dict of type {type(info_dict)} " |
| 45 | + f"but expected type or subtype `dict`." |
| 46 | + ) |
41 | 47 | for key in self.keys:
|
42 | 48 | if key in info_dict:
|
43 | 49 | tensordict[key] = info_dict[key]
|
@@ -67,6 +73,11 @@ class GymLikeEnv(_EnvWrapper):
|
67 | 73 | It is also expected that env.reset() returns an observation similar to the one observed after a step is completed.
|
68 | 74 | """
|
69 | 75 |
|
def __new__(cls, *args, **kwargs):
    """Allocate the instance and reset the class-level info-dict reader.

    ``__new__`` is implicitly a static method: Python calls it with the
    class passed explicitly as the first argument. Wrapping it in
    ``@classmethod`` binds the class a second time, so ``args`` silently
    starts with a duplicate of ``cls`` that is then forwarded to the parent
    ``__new__``. The decorator is therefore omitted and only the caller's
    own arguments are forwarded.

    Args:
        *args: positional constructor arguments, forwarded to the parent
            ``__new__``.
        **kwargs: keyword constructor arguments, forwarded to the parent
            ``__new__``.

    Returns:
        Whatever the parent ``__new__`` returns.
    """
    # Assigned on the class, not on the instance: every construction resets
    # the class-wide default to "no reader".
    # NOTE(review): the parent ``__new__`` is not visible here; confirm it
    # does not rely on receiving the duplicated class as an extra argument.
    cls._info_dict_reader = None
    return super().__new__(cls, *args, **kwargs)
| 80 | + |
70 | 81 | def _step(self, tensordict: _TensorDict) -> _TensorDict:
|
71 | 82 | action = tensordict.get("action")
|
72 | 83 | action_np = self.action_spec.to_numpy(action, safe=False)
|
@@ -98,7 +109,8 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
|
98 | 109 | )
|
99 | 110 | tensordict_out.set("reward", reward)
|
100 | 111 | tensordict_out.set("done", done)
|
101 |
| - self.info_dict_reader(info, tensordict_out) |
| 112 | + if self.info_dict_reader is not None: |
| 113 | + self.info_dict_reader(*info, tensordict_out) |
102 | 114 |
|
103 | 115 | return tensordict_out
|
104 | 116 |
|
@@ -156,17 +168,15 @@ def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
|
156 | 168 | self.info_dict_reader = info_dict_reader
|
157 | 169 | return self
|
158 | 170 |
|
def __repr__(self) -> str:
    """Return ``ClassName(env=..., batch_size=...)`` for this object.

    The class name is read from the instance, so a subclass reports its own
    name. ``self._env`` and ``self.batch_size`` are interpolated with their
    default string formatting.
    """
    rendered = (
        f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
    )
    return rendered
| 175 | + |
@property
def info_dict_reader(self):
    """Return the reader currently stored in ``self._info_dict_reader``.

    The stored value is returned unchanged; it may be ``None``.
    """
    current_reader = self._info_dict_reader
    return current_reader


@info_dict_reader.setter
def info_dict_reader(self, value: callable):
    """Store ``value`` as the reader, replacing any previous one.

    Args:
        value: the new reader; it is stored without validation.
    """
    self._info_dict_reader = value
|
168 |
| - |
169 |
| - def __repr__(self) -> str: |
170 |
| - return ( |
171 |
| - f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})" |
172 |
| - ) |
|
0 commit comments