
Commit 5bd4b4f

[BugFix] Fix EnvPool spec shapes (#932)
1 parent e2d5dbe commit 5bd4b4f

File tree

1 file changed: +11 -46 lines changed

torchrl/envs/vec_env.py

Lines changed: 11 additions & 46 deletions
@@ -33,11 +33,6 @@
 # Libraries necessary for MultiThreadedEnv
 import envpool
 
-try:
-    import gym
-except ModuleNotFoundError:
-    import gymnasium as gym
-
 import treevalue
 
 _has_envpool = True
@@ -1181,66 +1176,36 @@ def _get_input_spec(self) -> TensorSpec:
         # DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones.
 
         # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
-        action_spec = self._add_shape_to_spec(self._env.spec.action_space)
-        transformed_spec = _gym_to_torchrl_spec_transform(
-            action_spec,
+        action_spec = _gym_to_torchrl_spec_transform(
+            self._env.spec.action_space,
             device=self.device,
             categorical_action_encoding=True,
         )
-        if not transformed_spec.shape:
-            transformed_spec.shape = (self.num_workers,)
+        action_spec = self._add_shape_to_spec(action_spec)
         return CompositeSpec(
-            action=transformed_spec,
-            shape=transformed_spec.shape,
+            action=action_spec,
+            shape=(self.num_workers,),
         )
 
     def _get_observation_spec(self) -> TensorSpec:
         # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
-        obs_spec = self._add_shape_to_spec(self._env.spec.observation_space)
         observation_spec = _gym_to_torchrl_spec_transform(
-            obs_spec,
+            self._env.spec.observation_space,
             device=self.device,
             categorical_action_encoding=True,
         )
-        if isinstance(observation_spec, CompositeSpec):
-            observation_spec.shape = (self.num_workers,)
+        observation_spec = self._add_shape_to_spec(observation_spec)
         return CompositeSpec(
             observation=observation_spec,
-            shape=observation_spec.shape,
+            shape=(self.num_workers,),
         )
 
-    def _add_shape_to_spec(
-        self, spec: gym.spaces.space.Space
-    ) -> gym.spaces.space.Space:
-        if isinstance(spec, gym.spaces.Box):
-            return gym.spaces.Box(
-                low=np.stack([spec.low] * self.num_workers),
-                high=np.stack([spec.high] * self.num_workers),
-                dtype=spec.dtype,
-                shape=(self.num_workers, *spec.shape),
-            )
-        if isinstance(spec, gym.spaces.dict.Dict):
-            spec_dict = {}
-            for key in spec.keys():
-                if isinstance(spec[key], gym.spaces.Box):
-                    spec_dict[key] = gym.spaces.Box(
-                        low=np.stack([spec[key].low] * self.num_workers),
-                        high=np.stack([spec[key].high] * self.num_workers),
-                        dtype=spec[key].dtype,
-                        shape=(self.num_workers, *spec[key].shape),
-                    )
-                elif isinstance(spec[key], gym.spaces.dict.Dict):
-                    # If needed, we could add support by applying this function recursively
-                    raise TypeError("Nested specs with depth > 1 are not supported.")
-            return spec_dict
-        if isinstance(spec, gym.spaces.discrete.Discrete):
-            # Discrete spec in Gym doesn't have shape, so nothing to change
-            return spec
-        raise TypeError(f"Unsupported spec type {spec.__class__}.")
+    def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec:
+        return spec.expand((self.num_workers, *spec.shape))
 
     def _get_reward_spec(self) -> TensorSpec:
         return UnboundedContinuousTensorSpec(
-            device=self.device, shape=(self.num_workers,)
+            device=self.device, shape=(self.num_workers, 1)
         )
 
     def __repr__(self) -> str:
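
For reference, the rewritten _add_shape_to_spec now operates on a TorchRL TensorSpec rather than rebuilding Gym Box/Dict spaces by hand: the Gym space is first converted with _gym_to_torchrl_spec_transform, and the resulting spec is then expanded with a leading batch (num_workers) dimension via TensorSpec.expand. Below is a minimal, standalone sketch of that behavior (not part of the commit), assuming UnboundedContinuousTensorSpec and CompositeSpec from torchrl.data and an illustrative batch size of 4:

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

num_workers = 4  # illustrative batch size, stands in for self.num_workers

# A per-env observation spec, e.g. what _gym_to_torchrl_spec_transform could
# return for a Box observation space of shape (3,).
single_obs = UnboundedContinuousTensorSpec(shape=(3,), device="cpu")

# The new _add_shape_to_spec expands the spec with a leading batch dimension;
# expand works on leaf specs and CompositeSpec alike.
batched_obs = single_obs.expand((num_workers, *single_obs.shape))
print(batched_obs.shape)  # torch.Size([4, 3])

# Mirroring the new _get_observation_spec return value: the enclosing
# CompositeSpec carries the batch size as its own shape.
observation_spec = CompositeSpec(observation=batched_obs, shape=(num_workers,))
print(observation_spec.shape)  # torch.Size([4])

The reward spec change to shape=(self.num_workers, 1) follows the same pattern, giving each worker's reward an explicit trailing singleton dimension.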
