2
2
3
3
import torch
4
4
from tensordict .tensordict import TensorDict , TensorDictBase
5
+
5
6
from torchrl .data import CompositeSpec , DEVICE_TYPING , UnboundedContinuousTensorSpec
6
7
from torchrl .envs .common import _EnvWrapper , EnvBase
7
8
from torchrl .envs .libs .gym import _gym_to_torchrl_spec_transform
@@ -107,25 +108,6 @@ def __init__(
107
108
raise TypeError ("Env device is different from vmas device" )
108
109
kwargs ["device" ] = str (env .device )
109
110
super ().__init__ (** kwargs )
110
- if len (self .batch_size ) == 0 :
111
- # Batch size not set
112
- self .batch_size = torch .Size ((self .num_envs ,))
113
- elif len (self .batch_size ) == 1 :
114
- # Batch size is set
115
- if not self .batch_size [0 ] == self .num_envs :
116
- raise TypeError (
117
- "Batch size used in constructor does not match vmas batch size."
118
- )
119
- else :
120
- raise TypeError (
121
- "Batch size used in constructor is not compatible with vmas."
122
- )
123
- self .batch_size = torch .Size ([self .n_agents , * self .batch_size ])
124
- self .input_spec = self .input_spec .expand (self .batch_size )
125
- self .observation_spec = self .observation_spec .expand (self .batch_size )
126
- self .reward_spec = self .reward_spec .expand (
127
- [* self .batch_size , * self .reward_spec .shape ]
128
- )
129
111
130
112
@property
131
113
def lib (self ):
@@ -144,6 +126,22 @@ def _build_env(
144
126
if self .from_pixels :
145
127
raise NotImplementedError ("vmas rendering not yet implemented" )
146
128
129
+ # Adjust batch size
130
+ if len (self .batch_size ) == 0 :
131
+ # Batch size not set
132
+ self .batch_size = torch .Size ((env .num_envs ,))
133
+ elif len (self .batch_size ) == 1 :
134
+ # Batch size is set
135
+ if not self .batch_size [0 ] == env .num_envs :
136
+ raise TypeError (
137
+ "Batch size used in constructor does not match vmas batch size."
138
+ )
139
+ else :
140
+ raise TypeError (
141
+ "Batch size used in constructor is not compatible with vmas."
142
+ )
143
+ self .batch_size = torch .Size ([env .n_agents , * self .batch_size ])
144
+
147
145
return env
148
146
149
147
def _make_specs (
0 commit comments