@@ -129,9 +129,14 @@ class DiscreteActionVecMockEnv(_MockEnv):
129
129
)
130
130
action_spec = OneHotDiscreteTensorSpec (7 )
131
131
reward_spec = UnboundedContinuousTensorSpec ()
132
+
132
133
from_pixels = False
133
134
134
135
out_key = "observation"
136
+ _out_key = "observation_orig"
137
+ input_spec = CompositeSpec (
138
+ ** {_out_key : observation_spec ["next_observation" ], "action" : action_spec }
139
+ )
135
140
136
141
def _get_in_obs (self , obs ):
137
142
return obs
@@ -145,6 +150,7 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
145
150
tensordict = tensordict .select ().set (
146
151
"next_" + self .out_key , self ._get_out_obs (state )
147
152
)
153
+ tensordict = tensordict .set ("next_" + self ._out_key , self ._get_out_obs (state ))
148
154
tensordict .set ("done" , torch .zeros (* tensordict .shape , 1 , dtype = torch .bool ))
149
155
return tensordict
150
156
@@ -157,12 +163,12 @@ def _step(
157
163
assert (a .sum (- 1 ) == 1 ).all ()
158
164
assert not self .is_done , "trying to execute step in done env"
159
165
160
- obs = (
161
- self ._get_in_obs (self .current_tensordict .get (self .out_key ))
162
- + a / self .maxstep
163
- )
166
+ obs = self ._get_in_obs (tensordict .get (self ._out_key )) + a / self .maxstep
164
167
tensordict = tensordict .select () # empty tensordict
168
+
165
169
tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
170
+ tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
171
+
166
172
done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
167
173
reward = done .any (- 1 ).unsqueeze (- 1 )
168
174
# set done to False
@@ -182,6 +188,10 @@ class ContinuousActionVecMockEnv(_MockEnv):
182
188
from_pixels = False
183
189
184
190
out_key = "observation"
191
+ _out_key = "observation_orig"
192
+ input_spec = CompositeSpec (
193
+ ** {_out_key : observation_spec ["next_observation" ], "action" : action_spec }
194
+ )
185
195
186
196
def _get_in_obs (self , obs ):
187
197
return obs
@@ -193,9 +203,9 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
193
203
self .counter += 1
194
204
self .step_count = 0
195
205
state = torch .zeros (self .size ) + self .counter
196
- tensordict = tensordict .select (). set (
197
- "next_" + self .out_key , self ._get_out_obs (state )
198
- )
206
+ tensordict = tensordict .select ()
207
+ tensordict . set ( "next_" + self .out_key , self ._get_out_obs (state ) )
208
+ tensordict . set ( "next_" + self . _out_key , self . _get_out_obs ( state ) )
199
209
tensordict .set ("done" , torch .zeros (* tensordict .shape , 1 , dtype = torch .bool ))
200
210
return tensordict
201
211
@@ -208,11 +218,12 @@ def _step(
208
218
a = tensordict .get ("action" )
209
219
assert not self .is_done , "trying to execute step in done env"
210
220
211
- obs = self ._obs_step (
212
- self ._get_in_obs (self .current_tensordict .get (self .out_key )), a
213
- )
221
+ obs = self ._obs_step (self ._get_in_obs (tensordict .get (self ._out_key )), a )
214
222
tensordict = tensordict .select () # empty tensordict
223
+
215
224
tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
225
+ tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
226
+
216
227
done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
217
228
reward = done .any (- 1 ).unsqueeze (- 1 )
218
229
done = done .all (- 1 ).unsqueeze (- 1 )
@@ -251,6 +262,10 @@ class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv):
251
262
from_pixels = True
252
263
253
264
out_key = "pixels"
265
+ _out_key = "pixels_orig"
266
+ input_spec = CompositeSpec (
267
+ ** {_out_key : observation_spec ["next_pixels" ], "action" : action_spec }
268
+ )
254
269
255
270
def _get_out_obs (self , obs ):
256
271
obs = torch .diag_embed (obs , 0 , - 2 , - 1 ).unsqueeze (0 )
@@ -287,6 +302,10 @@ class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv):
287
302
from_pixels = True
288
303
289
304
out_key = "pixels"
305
+ _out_key = "pixels_orig"
306
+ input_spec = CompositeSpec (
307
+ ** {_out_key : observation_spec ["next_pixels" ], "action" : action_spec }
308
+ )
290
309
291
310
def _get_out_obs (self , obs ):
292
311
obs = torch .diag_embed (obs , 0 , - 2 , - 1 ).unsqueeze (0 )
0 commit comments