@@ -121,7 +121,7 @@ def __new__(
121
121
action_spec = NdUnboundedContinuousTensorSpec ((1 ,))
122
122
if observation_spec is None :
123
123
observation_spec = CompositeSpec (
124
- next_observation = NdUnboundedContinuousTensorSpec ((1 ,))
124
+ observation = NdUnboundedContinuousTensorSpec ((1 ,))
125
125
)
126
126
if reward_spec is None :
127
127
reward_spec = NdUnboundedContinuousTensorSpec ((1 ,))
@@ -152,19 +152,17 @@ def _step(self, tensordict):
152
152
)
153
153
done = self .counter >= self .max_val
154
154
done = torch .tensor ([done ], dtype = torch .bool , device = self .device )
155
- return TensorDict (
156
- {"reward" : n , "done" : done , "next_observation" : n .clone ()}, []
157
- )
155
+ return TensorDict ({"reward" : n , "done" : done , "observation" : n .clone ()}, [])
158
156
159
def _reset(
    self, tensordict: Optional[TensorDictBase] = None, **kwargs
) -> TensorDictBase:
    """Begin a new episode and return its initial ``done`` flag and observation.

    The episode horizon ``self.max_val`` is recomputed from the current
    ``self.counter`` as ``max(counter + 100, counter * 2)``; the counter
    itself is not modified here.

    Args:
        tensordict: Accepted for interface compatibility; it is not read
            by this method and may be omitted.
        **kwargs: Accepted and ignored.

    Returns:
        A ``TensorDict`` with an empty batch size containing:
        ``"done"`` -- a boolean tensor of shape ``(1,)``;
        ``"observation"`` -- a tensor of shape ``(1,)`` in the default
        floating dtype holding the current counter value.
    """
    self.max_val = max(self.counter + 100, self.counter * 2)

    n = torch.tensor(
        [self.counter], device=self.device, dtype=torch.get_default_dtype()
    )
    done = self.counter >= self.max_val
    done = torch.tensor([done], dtype=torch.bool, device=self.device)
    # The key must be exactly "observation" (no surrounding whitespace) so
    # that it matches the key the step output of this environment uses.
    return TensorDict({"done": done, "observation": n}, [])
168
166
169
167
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
    """Run one step by delegating directly to ``self.step``.

    No action is sampled in this method: ``tensordict`` (possibly ``None``)
    is forwarded unchanged and whatever ``self.step`` returns is handed back.

    Args:
        tensordict: Optional input forwarded verbatim to ``self.step``.

    Returns:
        The result of ``self.step(tensordict)``.
    """
    stepped = self.step(tensordict)
    return stepped
@@ -192,7 +190,7 @@ def __new__(
192
190
)
193
191
if observation_spec is None :
194
192
observation_spec = CompositeSpec (
195
- next_observation = NdUnboundedContinuousTensorSpec ((1 ,))
193
+ observation = NdUnboundedContinuousTensorSpec ((1 ,))
196
194
)
197
195
if reward_spec is None :
198
196
reward_spec = NdUnboundedContinuousTensorSpec ((1 ,))
@@ -226,7 +224,7 @@ def _step(self, tensordict):
226
224
)
227
225
228
226
return TensorDict (
229
- {"reward" : n , "done" : done , "next_observation " : n },
227
+ {"reward" : n , "done" : done , "observation " : n },
230
228
tensordict .batch_size ,
231
229
device = self .device ,
232
230
)
@@ -247,7 +245,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
247
245
done = torch .full (batch_size , done , dtype = torch .bool , device = self .device )
248
246
249
247
return TensorDict (
250
- {"reward" : n , "done" : done , "next_observation " : n },
248
+ {"reward" : n , "done" : done , "observation " : n },
251
249
batch_size ,
252
250
device = self .device ,
253
251
)
@@ -287,10 +285,8 @@ def __new__(
287
285
if observation_spec is None :
288
286
cls .out_key = "observation"
289
287
observation_spec = CompositeSpec (
290
- next_observation = NdUnboundedContinuousTensorSpec (
291
- shape = torch .Size ([size ])
292
- ),
293
- next_observation_orig = NdUnboundedContinuousTensorSpec (
288
+ observation = NdUnboundedContinuousTensorSpec (shape = torch .Size ([size ])),
289
+ observation_orig = NdUnboundedContinuousTensorSpec (
294
290
shape = torch .Size ([size ])
295
291
),
296
292
)
@@ -308,7 +304,7 @@ def __new__(
308
304
cls ._out_key = "observation_orig"
309
305
input_spec = CompositeSpec (
310
306
** {
311
- cls ._out_key : observation_spec ["next_observation " ],
307
+ cls ._out_key : observation_spec ["observation " ],
312
308
"action" : action_spec ,
313
309
}
314
310
)
@@ -325,15 +321,13 @@ def _get_in_obs(self, obs):
325
321
def _get_out_obs(self, obs):
    """Convert an internal observation into the value written to the output.

    This implementation is the identity: ``obs`` is returned as-is,
    without copying.
    """
    out_obs = obs
    return out_obs
327
323
328
def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase:
    """Reset the environment and return the initial tensordict.

    ``self.counter`` is incremented, and the initial state is a vector of
    length ``self.size`` filled with the new counter value. That state,
    passed through ``self._get_out_obs``, is stored under both
    ``self.out_key`` and ``self._out_key``, together with an all-``False``
    ``"done"`` entry.

    Args:
        tensordict: Optional input. When ``None``, an empty tensordict with
            ``self.batch_size`` on ``self.device`` is created; otherwise
            only its structure is reused, since its entries are dropped by
            ``select()``.

    Returns:
        The tensordict holding the two observation entries and ``"done"``.
    """
    self.counter += 1
    initial_state = torch.zeros(self.size) + self.counter

    if tensordict is None:
        tensordict = TensorDict({}, self.batch_size, device=self.device)

    # select() with no keys gives an empty tensordict to fill from scratch.
    out = tensordict.select()
    out = out.set(self.out_key, self._get_out_obs(initial_state))
    out = out.set(self._out_key, self._get_out_obs(initial_state))

    # One boolean flag per batch element, with a trailing singleton dim.
    done_flag = torch.zeros(*out.shape, 1, dtype=torch.bool)
    out.set("done", done_flag)
    return out
339
333
@@ -351,8 +345,8 @@ def _step(
351
345
obs = self ._get_in_obs (tensordict .get (self ._out_key )) + a / self .maxstep
352
346
tensordict = tensordict .select () # empty tensordict
353
347
354
- tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
355
- tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
348
+ tensordict .set (self .out_key , self ._get_out_obs (obs ))
349
+ tensordict .set (self ._out_key , self ._get_out_obs (obs ))
356
350
357
351
done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
358
352
reward = done .any (- 1 ).unsqueeze (- 1 )
@@ -379,10 +373,8 @@ def __new__(
379
373
if observation_spec is None :
380
374
cls .out_key = "observation"
381
375
observation_spec = CompositeSpec (
382
- next_observation = NdUnboundedContinuousTensorSpec (
383
- shape = torch .Size ([size ])
384
- ),
385
- next_observation_orig = NdUnboundedContinuousTensorSpec (
376
+ observation = NdUnboundedContinuousTensorSpec (shape = torch .Size ([size ])),
377
+ observation_orig = NdUnboundedContinuousTensorSpec (
386
378
shape = torch .Size ([size ])
387
379
),
388
380
)
@@ -395,7 +387,7 @@ def __new__(
395
387
cls ._out_key = "observation_orig"
396
388
input_spec = CompositeSpec (
397
389
** {
398
- cls ._out_key : observation_spec ["next_observation " ],
390
+ cls ._out_key : observation_spec ["observation " ],
399
391
"action" : action_spec ,
400
392
}
401
393
)
@@ -436,8 +428,8 @@ def _step(
436
428
obs = self ._obs_step (self ._get_in_obs (tensordict .get (self ._out_key )), a )
437
429
tensordict = tensordict .select () # empty tensordict
438
430
439
- tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
440
- tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
431
+ tensordict .set (self .out_key , self ._get_out_obs (obs ))
432
+ tensordict .set (self ._out_key , self ._get_out_obs (obs ))
441
433
442
434
done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
443
435
reward = done .any (- 1 ).unsqueeze (- 1 )
@@ -483,10 +475,8 @@ def __new__(
483
475
if observation_spec is None :
484
476
cls .out_key = "pixels"
485
477
observation_spec = CompositeSpec (
486
- next_pixels = NdUnboundedContinuousTensorSpec (
487
- shape = torch .Size ([1 , 7 , 7 ])
488
- ),
489
- next_pixels_orig = NdUnboundedContinuousTensorSpec (
478
+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size ([1 , 7 , 7 ])),
479
+ pixels_orig = NdUnboundedContinuousTensorSpec (
490
480
shape = torch .Size ([1 , 7 , 7 ])
491
481
),
492
482
)
@@ -499,7 +489,7 @@ def __new__(
499
489
cls ._out_key = "pixels_orig"
500
490
input_spec = CompositeSpec (
501
491
** {
502
- cls ._out_key : observation_spec ["next_pixels_orig " ],
492
+ cls ._out_key : observation_spec ["pixels_orig " ],
503
493
"action" : action_spec ,
504
494
}
505
495
)
@@ -537,10 +527,8 @@ def __new__(
537
527
if observation_spec is None :
538
528
cls .out_key = "pixels"
539
529
observation_spec = CompositeSpec (
540
- next_pixels = NdUnboundedContinuousTensorSpec (
541
- shape = torch .Size ([7 , 7 , 3 ])
542
- ),
543
- next_pixels_orig = NdUnboundedContinuousTensorSpec (
530
+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size ([7 , 7 , 3 ])),
531
+ pixels_orig = NdUnboundedContinuousTensorSpec (
544
532
shape = torch .Size ([7 , 7 , 3 ])
545
533
),
546
534
)
@@ -555,7 +543,7 @@ def __new__(
555
543
cls ._out_key = "pixels_orig"
556
544
input_spec = CompositeSpec (
557
545
** {
558
- cls ._out_key : observation_spec ["next_pixels_orig " ],
546
+ cls ._out_key : observation_spec ["pixels_orig " ],
559
547
"action" : action_spec ,
560
548
}
561
549
)
@@ -599,10 +587,8 @@ def __new__(
599
587
if observation_spec is None :
600
588
cls .out_key = "pixels"
601
589
observation_spec = CompositeSpec (
602
- next_pixels = NdUnboundedContinuousTensorSpec (
603
- shape = torch .Size (pixel_shape )
604
- ),
605
- next_pixels_orig = NdUnboundedContinuousTensorSpec (
590
+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size (pixel_shape )),
591
+ pixels_orig = NdUnboundedContinuousTensorSpec (
606
592
shape = torch .Size (pixel_shape )
607
593
),
608
594
)
@@ -615,7 +601,7 @@ def __new__(
615
601
if input_spec is None :
616
602
cls ._out_key = "pixels_orig"
617
603
input_spec = CompositeSpec (
618
- ** {cls ._out_key : observation_spec ["next_pixels " ], "action" : action_spec }
604
+ ** {cls ._out_key : observation_spec ["pixels " ], "action" : action_spec }
619
605
)
620
606
return super ().__new__ (
621
607
* args ,
@@ -650,10 +636,8 @@ def __new__(
650
636
if observation_spec is None :
651
637
cls .out_key = "pixels"
652
638
observation_spec = CompositeSpec (
653
- next_pixels = NdUnboundedContinuousTensorSpec (
654
- shape = torch .Size ([7 , 7 , 3 ])
655
- ),
656
- next_pixels_orig = NdUnboundedContinuousTensorSpec (
639
+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size ([7 , 7 , 3 ])),
640
+ pixels_orig = NdUnboundedContinuousTensorSpec (
657
641
shape = torch .Size ([7 , 7 , 3 ])
658
642
),
659
643
)
@@ -714,7 +698,7 @@ def __init__(
714
698
batch_size = batch_size ,
715
699
)
716
700
self .observation_spec = CompositeSpec (
717
- next_hidden_observation = NdUnboundedContinuousTensorSpec ((4 ,))
701
+ hidden_observation = NdUnboundedContinuousTensorSpec ((4 ,))
718
702
)
719
703
self .input_spec = CompositeSpec (
720
704
hidden_observation = NdUnboundedContinuousTensorSpec ((4 ,)),
@@ -728,9 +712,6 @@ def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
728
712
"hidden_observation" : self .input_spec ["hidden_observation" ].rand (
729
713
self .batch_size
730
714
),
731
- "next_hidden_observation" : self .observation_spec [
732
- "next_hidden_observation"
733
- ].rand (self .batch_size ),
734
715
},
735
716
batch_size = self .batch_size ,
736
717
device = self .device ,
0 commit comments