72
72
def _apply_to_composite (function ):
73
73
def new_fun (self , observation_spec ):
74
74
if isinstance (observation_spec , CompositeSpec ):
75
- d = copy ( observation_spec ._specs )
75
+ d = observation_spec ._specs
76
76
for key_in , key_out in zip (self .keys_in , self .keys_out ):
77
77
if key_in in observation_spec .keys ():
78
78
d [key_out ] = function (self , observation_spec [key_in ])
@@ -506,7 +506,9 @@ def __getattr__(self, attr: str) -> Any:
506
506
)
507
507
508
508
def __repr__ (self ) -> str :
509
- return f"TransformedEnv(env={ self .base_env } , transform={ self .transform } )"
509
+ env_str = indent (f"env={ self .base_env } " , 4 * " " )
510
+ t_str = indent (f"transform={ self .transform } " , 4 * " " )
511
+ return f"TransformedEnv(\n { env_str } ,\n { t_str } )"
510
512
511
513
def _erase_metadata (self ):
512
514
if self .cache_specs :
@@ -621,7 +623,9 @@ def __getitem__(self, item: Union[int, slice, List]) -> Union:
621
623
transform = self .transforms
622
624
transform = transform [item ]
623
625
if not isinstance (transform , Transform ):
624
- return Compose (* self .transforms [item ])
626
+ out = Compose (* self .transforms [item ])
627
+ out .set_parent (self .parent )
628
+ return out
625
629
return transform
626
630
627
631
def dump (self , ** kwargs ) -> None :
@@ -737,7 +741,7 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor:
737
741
738
742
@_apply_to_composite
739
743
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
740
- self ._pixel_observation (observation_spec )
744
+ observation_spec = self ._pixel_observation (deepcopy ( observation_spec ) )
741
745
observation_spec .shape = torch .Size (
742
746
[
743
747
* observation_spec .shape [:- 3 ],
@@ -747,13 +751,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
747
751
]
748
752
)
749
753
observation_spec .dtype = self .dtype
750
- observation_spec = observation_spec
751
754
return observation_spec
752
755
753
756
def _pixel_observation (self , spec : TensorSpec ) -> None :
754
- if isinstance (spec , BoundedTensorSpec ):
757
+ if isinstance (spec . space , ContinuousBox ):
755
758
spec .space .maximum = self ._apply_transform (spec .space .maximum )
756
759
spec .space .minimum = self ._apply_transform (spec .space .minimum )
760
+ return spec
757
761
758
762
759
763
class RewardClipping (Transform ):
@@ -899,6 +903,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
899
903
900
904
@_apply_to_composite
901
905
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
906
+ observation_spec = deepcopy (observation_spec )
902
907
space = observation_spec .space
903
908
if isinstance (space , ContinuousBox ):
904
909
space .minimum = self ._apply_transform (space .minimum )
@@ -962,7 +967,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
962
967
}
963
968
)
964
969
else :
965
- _observation_spec = observation_spec
970
+ _observation_spec = deepcopy (observation_spec )
971
+
966
972
space = _observation_spec .space
967
973
if isinstance (space , ContinuousBox ):
968
974
space .minimum = self ._apply_transform (space .minimum )
@@ -1019,6 +1025,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
1019
1025
1020
1026
@_apply_to_composite
1021
1027
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1028
+ observation_spec = deepcopy (observation_spec )
1022
1029
space = observation_spec .space
1023
1030
if isinstance (space , ContinuousBox ):
1024
1031
space .minimum = self ._apply_transform (space .minimum )
@@ -1122,25 +1129,26 @@ def _transform_spec(self, spec: TensorSpec) -> None:
1122
1129
spec .shape = space .minimum .shape
1123
1130
else :
1124
1131
spec .shape = self ._apply_transform (torch .zeros (spec .shape )).shape
1132
+ return spec
1125
1133
1126
1134
def transform_action_spec (self , action_spec : TensorSpec ) -> TensorSpec :
1127
1135
if "action" in self .keys_inv_in :
1128
- self ._transform_spec (action_spec )
1136
+ action_spec = self ._transform_spec (deepcopy ( action_spec ) )
1129
1137
return action_spec
1130
1138
1131
1139
def transform_input_spec (self , input_spec : TensorSpec ) -> TensorSpec :
1132
1140
for key in self .keys_inv_in :
1133
- self ._transform_spec (input_spec [key ])
1141
+ input_spec = self ._transform_spec (deepcopy ( input_spec [key ]) )
1134
1142
return input_spec
1135
1143
1136
1144
def transform_reward_spec (self , reward_spec : TensorSpec ) -> TensorSpec :
1137
1145
if "reward" in self .keys_in :
1138
- self ._transform_spec (reward_spec )
1146
+ reward_spec = self ._transform_spec (deepcopy ( reward_spec ) )
1139
1147
return reward_spec
1140
1148
1141
1149
@_apply_to_composite
1142
1150
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1143
- self ._transform_spec (observation_spec )
1151
+ observation_spec = self ._transform_spec (deepcopy ( observation_spec ) )
1144
1152
return observation_spec
1145
1153
1146
1154
def __repr__ (self ) -> str :
@@ -1207,6 +1215,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
1207
1215
1208
1216
@_apply_to_composite
1209
1217
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1218
+ observation_spec = deepcopy (observation_spec )
1210
1219
space = observation_spec .space
1211
1220
if isinstance (space , ContinuousBox ):
1212
1221
space .minimum = self ._apply_transform (space .minimum )
@@ -1295,6 +1304,7 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
1295
1304
1296
1305
@_apply_to_composite
1297
1306
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1307
+ observation_spec = deepcopy (observation_spec )
1298
1308
space = observation_spec .space
1299
1309
if isinstance (space , ContinuousBox ):
1300
1310
space .minimum = self ._apply_transform (space .minimum )
0 commit comments