@@ -757,7 +757,7 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor:
757
757
758
758
@_apply_to_composite
759
759
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
760
- observation_spec = self ._pixel_observation (deepcopy ( observation_spec ) )
760
+ observation_spec = self ._pixel_observation (observation_spec )
761
761
observation_spec .shape = torch .Size (
762
762
[
763
763
* observation_spec .shape [:- 3 ],
@@ -913,7 +913,6 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
913
913
914
914
@_apply_to_composite
915
915
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
916
- observation_spec = deepcopy (observation_spec )
917
916
space = observation_spec .space
918
917
if isinstance (space , ContinuousBox ):
919
918
space .minimum = self ._apply_transform (space .minimum )
@@ -970,20 +969,17 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
970
969
for key , _obs_spec in observation_spec ._specs .items ()
971
970
}
972
971
)
973
- else :
974
- _observation_spec = deepcopy (observation_spec )
975
972
976
- space = _observation_spec .space
973
+ space = observation_spec .space
977
974
if isinstance (space , ContinuousBox ):
978
975
space .minimum = self ._apply_transform (space .minimum )
979
976
space .maximum = self ._apply_transform (space .maximum )
980
- _observation_spec .shape = space .minimum .shape
977
+ observation_spec .shape = space .minimum .shape
981
978
else :
982
- _observation_spec .shape = self ._apply_transform (
983
- torch .zeros (_observation_spec .shape )
979
+ observation_spec .shape = self ._apply_transform (
980
+ torch .zeros (observation_spec .shape )
984
981
).shape
985
982
986
- observation_spec = _observation_spec
987
983
return observation_spec
988
984
989
985
def __repr__ (self ) -> str :
@@ -1036,7 +1032,6 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
1036
1032
1037
1033
@_apply_to_composite
1038
1034
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1039
- observation_spec = deepcopy (observation_spec )
1040
1035
space = observation_spec .space
1041
1036
1042
1037
if isinstance (space , ContinuousBox ):
@@ -1139,17 +1134,17 @@ def _transform_spec(self, spec: TensorSpec) -> None:
1139
1134
1140
1135
def transform_input_spec (self , input_spec : TensorSpec ) -> TensorSpec :
1141
1136
for key in self .keys_inv_in :
1142
- input_spec = self ._transform_spec (deepcopy ( input_spec [key ]) )
1137
+ input_spec = self ._transform_spec (input_spec [key ])
1143
1138
return input_spec
1144
1139
1145
1140
def transform_reward_spec (self , reward_spec : TensorSpec ) -> TensorSpec :
1146
1141
if "reward" in self .keys_in :
1147
- reward_spec = self ._transform_spec (deepcopy ( reward_spec ) )
1142
+ reward_spec = self ._transform_spec (reward_spec )
1148
1143
return reward_spec
1149
1144
1150
1145
@_apply_to_composite
1151
1146
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1152
- observation_spec = self ._transform_spec (deepcopy ( observation_spec ) )
1147
+ observation_spec = self ._transform_spec (observation_spec )
1153
1148
return observation_spec
1154
1149
1155
1150
def __repr__ (self ) -> str :
@@ -1213,7 +1208,6 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
1213
1208
1214
1209
@_apply_to_composite
1215
1210
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1216
- observation_spec = deepcopy (observation_spec )
1217
1211
space = observation_spec .space
1218
1212
if isinstance (space , ContinuousBox ):
1219
1213
space .minimum = self ._apply_transform (space .minimum )
@@ -1303,7 +1297,6 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
1303
1297
1304
1298
@_apply_to_composite
1305
1299
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1306
- observation_spec = deepcopy (observation_spec )
1307
1300
space = observation_spec .space
1308
1301
if isinstance (space , ContinuousBox ):
1309
1302
space .minimum = self ._apply_transform (space .minimum )
@@ -1756,9 +1749,8 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
1756
1749
return action
1757
1750
1758
1751
def tranform_input_spec (self , input_spec : CompositeSpec ):
1759
- input_spec_out = deepcopy (input_spec )
1760
- input_spec_out ["action" ] = self .transform_action_spec (input_spec_out ["action" ])
1761
- return input_spec_out
1752
+ input_spec ["action" ] = self .transform_action_spec (input_spec ["action" ])
1753
+ return input_spec
1762
1754
1763
1755
def __repr__ (self ) -> str :
1764
1756
return (
0 commit comments