Skip to content

Commit 732e3a2

Browse files
[BugFix] Vmas expanded specs (#942)
1 parent c0e8a1c commit 732e3a2

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

torchrl/envs/libs/vmas.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from tensordict.tensordict import TensorDict, TensorDictBase
5+
56
from torchrl.data import CompositeSpec, DEVICE_TYPING, UnboundedContinuousTensorSpec
67
from torchrl.envs.common import _EnvWrapper, EnvBase
78
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
@@ -107,25 +108,6 @@ def __init__(
107108
raise TypeError("Env device is different from vmas device")
108109
kwargs["device"] = str(env.device)
109110
super().__init__(**kwargs)
110-
if len(self.batch_size) == 0:
111-
# Batch size not set
112-
self.batch_size = torch.Size((self.num_envs,))
113-
elif len(self.batch_size) == 1:
114-
# Batch size is set
115-
if not self.batch_size[0] == self.num_envs:
116-
raise TypeError(
117-
"Batch size used in constructor does not match vmas batch size."
118-
)
119-
else:
120-
raise TypeError(
121-
"Batch size used in constructor is not compatible with vmas."
122-
)
123-
self.batch_size = torch.Size([self.n_agents, *self.batch_size])
124-
self.input_spec = self.input_spec.expand(self.batch_size)
125-
self.observation_spec = self.observation_spec.expand(self.batch_size)
126-
self.reward_spec = self.reward_spec.expand(
127-
[*self.batch_size, *self.reward_spec.shape]
128-
)
129111

130112
@property
131113
def lib(self):
@@ -144,6 +126,22 @@ def _build_env(
144126
if self.from_pixels:
145127
raise NotImplementedError("vmas rendering not yet implemented")
146128

129+
# Adjust batch size
130+
if len(self.batch_size) == 0:
131+
# Batch size not set
132+
self.batch_size = torch.Size((env.num_envs,))
133+
elif len(self.batch_size) == 1:
134+
# Batch size is set
135+
if not self.batch_size[0] == env.num_envs:
136+
raise TypeError(
137+
"Batch size used in constructor does not match vmas batch size."
138+
)
139+
else:
140+
raise TypeError(
141+
"Batch size used in constructor is not compatible with vmas."
142+
)
143+
self.batch_size = torch.Size([env.n_agents, *self.batch_size])
144+
147145
return env
148146

149147
def _make_specs(

0 commit comments

Comments
 (0)