33 | 33 | # Libraries necessary for MultiThreadedEnv
34 | 34 | import envpool
35 | 35 |
36 |    | - try:
37 |    | -     import gym
38 |    | - except ModuleNotFoundError:
39 |    | -     import gymnasium as gym
40 |    | -
41 | 36 | import treevalue
42 | 37 |
43 | 38 | _has_envpool = True
@@ -1181,66 +1176,36 @@ def _get_input_spec(self) -> TensorSpec:
1181 | 1176 |         # DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones.
1182 | 1177 |
1183 | 1178 |         # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
1184 |      | -         action_spec = self._add_shape_to_spec(self._env.spec.action_space)
1185 |      | -         transformed_spec = _gym_to_torchrl_spec_transform(
1186 |      | -             action_spec,
     | 1179 | +         action_spec = _gym_to_torchrl_spec_transform(
     | 1180 | +             self._env.spec.action_space,
1187 | 1181 |             device=self.device,
1188 | 1182 |             categorical_action_encoding=True,
1189 | 1183 |         )
1190 |      | -         if not transformed_spec.shape:
1191 |      | -             transformed_spec.shape = (self.num_workers,)
     | 1184 | +         action_spec = self._add_shape_to_spec(action_spec)
1192 | 1185 |         return CompositeSpec(
1193 |      | -             action=transformed_spec,
1194 |      | -             shape=transformed_spec.shape,
     | 1186 | +             action=action_spec,
     | 1187 | +             shape=(self.num_workers,),
1195 | 1188 |         )
1196 | 1189 |
1197 | 1190 |     def _get_observation_spec(self) -> TensorSpec:
1198 | 1191 |         # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
1199 |      | -         obs_spec = self._add_shape_to_spec(self._env.spec.observation_space)
1200 | 1192 |         observation_spec = _gym_to_torchrl_spec_transform(
1201 |      | -             obs_spec,
     | 1193 | +             self._env.spec.observation_space,
1202 | 1194 |             device=self.device,
1203 | 1195 |             categorical_action_encoding=True,
1204 | 1196 |         )
1205 |      | -         if isinstance(observation_spec, CompositeSpec):
1206 |      | -             observation_spec.shape = (self.num_workers,)
     | 1197 | +         observation_spec = self._add_shape_to_spec(observation_spec)
1207 | 1198 |         return CompositeSpec(
1208 | 1199 |             observation=observation_spec,
1209 |      | -             shape=observation_spec.shape,
     | 1200 | +             shape=(self.num_workers,),
1210 | 1201 |         )
1211 | 1202 |
1212 |      | -     def _add_shape_to_spec(
1213 |      | -         self, spec: gym.spaces.space.Space
1214 |      | -     ) -> gym.spaces.space.Space:
1215 |      | -         if isinstance(spec, gym.spaces.Box):
1216 |      | -             return gym.spaces.Box(
1217 |      | -                 low=np.stack([spec.low] * self.num_workers),
1218 |      | -                 high=np.stack([spec.high] * self.num_workers),
1219 |      | -                 dtype=spec.dtype,
1220 |      | -                 shape=(self.num_workers, *spec.shape),
1221 |      | -             )
1222 |      | -         if isinstance(spec, gym.spaces.dict.Dict):
1223 |      | -             spec_dict = {}
1224 |      | -             for key in spec.keys():
1225 |      | -                 if isinstance(spec[key], gym.spaces.Box):
1226 |      | -                     spec_dict[key] = gym.spaces.Box(
1227 |      | -                         low=np.stack([spec[key].low] * self.num_workers),
1228 |      | -                         high=np.stack([spec[key].high] * self.num_workers),
1229 |      | -                         dtype=spec[key].dtype,
1230 |      | -                         shape=(self.num_workers, *spec[key].shape),
1231 |      | -                     )
1232 |      | -                 elif isinstance(spec[key], gym.spaces.dict.Dict):
1233 |      | -                     # If needed, we could add support by applying this function recursively
1234 |      | -                     raise TypeError("Nested specs with depth > 1 are not supported.")
1235 |      | -             return spec_dict
1236 |      | -         if isinstance(spec, gym.spaces.discrete.Discrete):
1237 |      | -             # Discrete spec in Gym doesn't have shape, so nothing to change
1238 |      | -             return spec
1239 |      | -         raise TypeError(f"Unsupported spec type {spec.__class__}.")
     | 1203 | +     def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec:
     | 1204 | +         return spec.expand((self.num_workers, *spec.shape))
1240 | 1205 |
1241 | 1206 |     def _get_reward_spec(self) -> TensorSpec:
1242 | 1207 |         return UnboundedContinuousTensorSpec(
1243 |      | -             device=self.device, shape=(self.num_workers,)
     | 1208 | +             device=self.device, shape=(self.num_workers, 1)
1244 | 1209 |         )
1245 | 1210 |
1246 | 1211 |     def __repr__(self) -> str:
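The new `_add_shape_to_spec` batches specs after the Gym-to-TorchRL conversion, so it can lean on `TensorSpec.expand` instead of rebuilding `gym.spaces.Box` objects (and `np.stack`-ing their bounds) by hand. A minimal sketch of the resulting flow, assuming the spec classes are importable from `torchrl.data`; the shapes and names below are illustrative, not taken from the file:

```python
# Sketch of the new spec-batching flow; values are illustrative.
import torch
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

num_workers = 4

# A per-environment observation spec, roughly what
# _gym_to_torchrl_spec_transform could yield for a 3-dimensional Box.
single_obs = UnboundedContinuousTensorSpec(shape=(3,), device="cpu")

# _add_shape_to_spec now just prepends the worker dimension via expand().
batched_obs = single_obs.expand((num_workers, *single_obs.shape))
assert batched_obs.shape == torch.Size([num_workers, 3])

# The enclosing CompositeSpec carries the same leading batch size,
# which is what the EnvBase checks compare against.
obs_spec = CompositeSpec(observation=batched_obs, shape=(num_workers,))
```

The reward-spec change from `(self.num_workers,)` to `(self.num_workers, 1)` appears to follow the same batching convention while keeping a trailing singleton feature dimension for the per-step reward.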