@@ -113,13 +113,6 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Resets a transform if it is stateful."""
        return tensordict

-    def _check_inplace(self) -> None:
-        if not hasattr(self, "inplace"):
-            raise AttributeError(
-                f"Transform of class {self.__class__.__name__} has no "
-                f"attribute inplace, consider implementing it."
-            )
-
    def init(self, tensordict) -> None:
        pass

@@ -134,11 +127,13 @@ def _apply_transform(self, obs: torch.Tensor) -> None:

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Reads the input tensordict, and for the selected keys, applies the transform."""
-        self._check_inplace()
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in tensordict.keys(include_nested=True):
                observation = self._apply_transform(tensordict.get(in_key))
-                tensordict.set(out_key, observation, inplace=self.inplace)
+                tensordict.set(
+                    out_key,
+                    observation,
+                )
        return tensordict

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -160,11 +155,13 @@ def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
-        self._check_inplace()
        for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
            if in_key in tensordict.keys(include_nested=True):
                observation = self._inv_apply_transform(tensordict.get(in_key))
-                tensordict.set(out_key, observation, inplace=self.inplace)
+                tensordict.set(
+                    out_key,
+                    observation,
+                )
        return tensordict

    def inv(self, tensordict: TensorDictBase) -> TensorDictBase:
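With `_check_inplace` gone and `_call`/`_inv_call` writing their results back through `tensordict.set`, subclasses no longer need to declare an `inplace` attribute; they only have to return a fresh tensor from `_apply_transform`. The sketch below illustrates a custom transform under this convention; the `RewardShift` class and its `shift` parameter are made up for illustration and are not part of torchrl.

```python
import torch

from torchrl.envs.transforms import Transform


class RewardShift(Transform):
    """Illustrative transform: adds a constant offset to the reward."""

    def __init__(self, shift: float = 1.0):
        # No `inplace` class attribute is required anymore.
        super().__init__(in_keys=["reward"], out_keys=["reward"])
        self.shift = shift

    def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
        # Return a new tensor; the base class writes it back with
        # `tensordict.set(out_key, value)` rather than updating in place.
        return reward + self.shift
```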
@@ -607,8 +604,6 @@ def __del__(self):
class ObservationTransform(Transform):
    """Abstract class for transformations of the observations."""

-    inplace = False
-
    def __init__(
        self,
        in_keys: Optional[Sequence[str]] = None,
@@ -634,8 +629,6 @@ class Compose(Transform):

    """

-    inplace = False
-
    def __init__(self, *transforms: Transform):
        super().__init__(in_keys=[])
        self.transforms = nn.ModuleList(transforms)
@@ -773,8 +766,6 @@ class ToTensorImage(ObservationTransform):
        torch.Size([1, 1, 3, 10, 11]) torch.float32
    """

-    inplace = False
-
    def __init__(
        self,
        unsqueeze: bool = False,
@@ -827,8 +818,6 @@ class RewardClipping(Transform):

    """

-    inplace = True
-
    def __init__(
        self,
        clamp_min: float = None,
@@ -850,11 +839,11 @@ def __init__(

    def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
        if self.clamp_max is not None and self.clamp_min is not None:
-            reward = reward.clamp_(self.clamp_min, self.clamp_max)
+            reward = reward.clamp(self.clamp_min, self.clamp_max)
        elif self.clamp_min is not None:
-            reward = reward.clamp_min_(self.clamp_min)
+            reward = reward.clamp_min(self.clamp_min)
        elif self.clamp_max is not None:
-            reward = reward.clamp_max_(self.clamp_max)
+            reward = reward.clamp_max(self.clamp_max)
        return reward

    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
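As context for the change above: the in-place variants (`clamp_`, `clamp_min_`, `clamp_max_`) mutate the tensor they are called on, which can silently alter a reward tensor still referenced elsewhere (for instance by the untransformed parent tensordict), whereas the out-of-place variants return a new tensor. A small self-contained torch snippet, not part of the patch, illustrating the difference:

```python
import torch

reward = torch.tensor([-2.0, 0.5, 3.0])

clipped = reward.clamp(-1.0, 1.0)  # out-of-place: returns a new tensor
print(reward)   # tensor([-2.0000,  0.5000,  3.0000])  -- original unchanged
print(clipped)  # tensor([-1.0000,  0.5000,  1.0000])

reward.clamp_(-1.0, 1.0)           # in-place: mutates `reward` itself
print(reward)   # tensor([-1.0000,  0.5000,  1.0000])
```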
@@ -884,8 +873,6 @@ def __repr__(self) -> str:
class BinarizeReward(Transform):
    """Maps the reward to a binary value (0 or 1) if the reward is null or non-null, respectively."""

-    inplace = True
-
    def __init__(
        self,
        in_keys: Optional[Sequence[str]] = None,
@@ -917,8 +904,6 @@ class Resize(ObservationTransform):
        interpolation (str): interpolation method
    """

-    inplace = False
-
    def __init__(
        self,
        w: int,
@@ -986,8 +971,6 @@ class CenterCrop(ObservationTransform):
        h (int, optional): resulting height. If None, then w is used (square crop).
    """

-    inplace = False
-
    def __init__(
        self,
        w: int,
@@ -1043,8 +1026,6 @@ class FlattenObservation(ObservationTransform):
        last_dim (int): last dimension of the dimensions to flatten.
    """

-    inplace = False
-
    def __init__(
        self,
        first_dim: int,
@@ -1115,7 +1096,6 @@ class UnsqueezeTransform(Transform):
    """

    invertible = True
-    inplace = False

    @classmethod
    def __new__(cls, *args, **kwargs):
@@ -1232,7 +1212,6 @@ class SqueezeTransform(UnsqueezeTransform):
    """

    invertible = True
-    inplace = False

    def __init__(
        self,
@@ -1269,8 +1248,6 @@ def inv(self, tensordict: TensorDictBase) -> TensorDictBase:
class GrayScale(ObservationTransform):
    """Turns a pixel observation to grayscale."""

-    inplace = False
-
    def __init__(self, in_keys: Optional[Sequence[str]] = None):
        if in_keys is None:
            in_keys = IMAGE_KEYS
@@ -1342,8 +1319,6 @@ class ObservationNorm(ObservationTransform):

    """

-    inplace = True
-
    def __init__(
        self,
        loc: Optional[float, torch.Tensor] = None,
@@ -1471,7 +1446,6 @@ def raise_initialization_exception(module):
            raise RuntimeError("Non-finite values found in loc")
        if not torch.isfinite(scale).all():
            raise RuntimeError("Non-finite values found in scale")
-
        self.register_buffer("loc", loc)
        self.register_buffer("scale", scale.clamp_min(self.eps))

@@ -1662,8 +1636,6 @@ class RewardScaling(Transform):
            as it is done for standardization. Default is `False`.
    """

-    inplace = True
-
    def __init__(
        self,
        loc: Union[float, torch.Tensor],
@@ -1717,8 +1689,6 @@ def __repr__(self) -> str:
class FiniteTensorDictCheck(Transform):
    """This transform will check that all the items of the tensordict are finite, and raise an exception if they are not."""

-    inplace = False
-
    def __init__(self):
        super().__init__(in_keys=[])

@@ -1741,7 +1711,6 @@ class DoubleToFloat(Transform):
    """

    invertible = True
-    inplace = False

    def __init__(
        self,
@@ -1835,7 +1804,6 @@ class CatTensors(Transform):
    """

    invertible = False
-    inplace = False

    def __init__(
        self,
@@ -1992,8 +1960,6 @@ class DiscreteActionProjection(Transform):
        tensor([1])
    """

-    inplace = False
-
    def __init__(self, max_n: int, m: int, action_key: str = "action"):
        super().__init__([action_key])
        self.max_n = max_n
@@ -2035,8 +2001,6 @@ class FrameSkipTransform(Transform):

    """

-    inplace = False
-
    def __init__(self, frame_skip: int = 1):
        super().__init__([])
        if frame_skip < 1:
@@ -2069,8 +2033,6 @@ class NoopResetEnv(Transform):

    """

-    inplace = True
-
    def __init__(self, noops: int = 30, random: bool = True):
        """Sample initial states by taking random number of no-ops on reset.

@@ -2169,8 +2131,6 @@ class TensorDictPrimer(Transform):
            is_shared=False)
    """

-    inplace = False
-
    def __init__(self, random=False, default_value=0.0, **kwargs):
        self.primers = kwargs
        self.random = random
@@ -2271,8 +2231,6 @@ class gSDENoise(Transform):
    See the :func:`~torchrl.modules.models.exploration.gSDEModule' for more info.
    """

-    inplace = False
-
    def __init__(
        self,
        state_dim=None,
@@ -2351,8 +2309,6 @@ class VecNorm(Transform):

    """

-    inplace = True
-
    def __init__(
        self,
        in_keys: Optional[Sequence[str]] = None,
@@ -2402,7 +2358,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
                key, tensordict.get(key), N=max(1, tensordict.numel())
            )

-            tensordict.set_(key, new_val)
+            tensordict.set(key, new_val)

        if self.lock is not None:
            self.lock.release()
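For reference, `TensorDict.set_` writes into the storage of an already existing entry and raises if the key is missing, while `TensorDict.set` binds the key to the new value, creating it when needed. A minimal sketch of that difference, assuming the standalone `tensordict` package and not part of the patch:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"count": torch.zeros(3)}, batch_size=[])

td.set("count", torch.ones(3))            # rebinds the key to a new tensor
td.set_("count", torch.full((3,), 2.0))   # copies into the existing entry
td.set("extra", torch.ones(3))            # creates a new key
# td.set_("missing", torch.ones(3))       # would raise: `set_` needs an existing key
```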
@@ -2582,8 +2538,6 @@ class RewardSum(Transform):
            this transform has no effect.
    """

-    inplace = True
-
    def __init__(
        self,
        in_keys: Optional[Sequence[str]] = None,
@@ -2654,7 +2608,6 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Updates the episode rewards with the step rewards."""
        # Sanity checks
-        self._check_inplace()
        for in_key in self.in_keys:
            if in_key not in tensordict.keys():
                return tensordict
@@ -2727,7 +2680,6 @@ class StepCounter(Transform):
    """

    invertible = False
-    inplace = True

    def __init__(self, max_steps: Optional[int] = None):
        if max_steps is not None and max_steps < 1:
@@ -2799,8 +2751,6 @@ class ExcludeTransform(Transform):

    """

-    inplace = False
-
    def __init__(self, *excluded_keys):
        super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
        if not all(isinstance(item, str) for item in excluded_keys):
@@ -2840,8 +2790,6 @@ class SelectTransform(Transform):

    """

-    inplace = False
-
    def __init__(self, *selected_keys):
        super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
        if not all(isinstance(item, str) for item in selected_keys):
@@ -2887,7 +2835,6 @@ class TimeMaxPool(Transform):
        T (int, optional): Number of time steps over which to apply max pooling.
    """

-    inplace = False
    invertible = False

    def __init__(