
Commit 835cd40

[Refactor] Removing inplace transform attribute (#871)
1 parent 37e0c53 commit 835cd40
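
This commit drops the `inplace` class attribute from every `Transform` subclass and removes the `_check_inplace()` helper; transformed values are now always written back with a plain `tensordict.set(out_key, value)`. As a rough illustration, here is a minimal sketch of what a custom observation transform might look like after this change. The class name, the doubling logic, and the default "observation" key are made up for the example; the only pieces taken from the diff below are the `ObservationTransform` base class, its `in_keys` constructor argument, and the `_apply_transform` hook.

import torch

# Assumed import path: ObservationTransform is defined in
# torchrl/envs/transforms/transforms.py (see the diff below).
from torchrl.envs.transforms.transforms import ObservationTransform


class DoubleObservation(ObservationTransform):
    """Hypothetical transform that doubles the selected observation entries.

    No `inplace` attribute is declared: after this commit the base class
    simply writes the returned tensor back into the tensordict.
    """

    def __init__(self, in_keys=None):
        if in_keys is None:
            in_keys = ["observation"]
        super().__init__(in_keys=in_keys)

    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        # Return a fresh tensor rather than mutating `obs` in place.
        return obs * 2.0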

File tree

2 files changed: +12, -91 lines

test/test_transforms.py

Lines changed: 0 additions & 26 deletions
@@ -1259,31 +1259,6 @@ def test_observationnorm_uninitialized_stats_error(self):
         with pytest.raises(RuntimeError, match=err_msg):
             transform._apply_transform(torch.Tensor([1]))
 
-    @pytest.mark.parametrize("device", get_available_devices())
-    def test_observationnorm_infinite_stats_error(self, device):
-        base_env = ContinuousActionVecMockEnv(
-            observation_spec=CompositeSpec(
-                observation=BoundedTensorSpec(
-                    minimum=1, maximum=1, shape=torch.Size([1])
-                ),
-                observation_orig=BoundedTensorSpec(
-                    minimum=1, maximum=1, shape=torch.Size([1])
-                ),
-            ),
-            action_spec=BoundedTensorSpec(minimum=1, maximum=1, shape=torch.Size((1,))),
-            seed=0,
-        )
-        base_env.out_key = "observation"
-        t_env = TransformedEnv(
-            base_env,
-            transform=ObservationNorm(in_keys="observation"),
-        )
-        t_env.append_transform(ObservationNorm(in_keys="observation"))
-        err_msg = "Non-finite values found in"
-        with pytest.raises(RuntimeError, match=err_msg):
-            for transform in t_env.transform:
-                transform.init_stats(num_iter=100)
-
     def test_catframes_transform_observation_spec(self):
         N = 4
         key1 = "first key"
@@ -1649,7 +1624,6 @@ def test_binarized_reward(self, device, batch):
             device=device,
         )
         br(td)
-        assert td["reward"] is reward
         assert (td["reward"] != reward_copy).all()
         assert (td["misc"] == misc_copy).all()
         assert (torch.count_nonzero(td["reward"]) == torch.sum(reward_copy > 0)).all()
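
The removed `assert td["reward"] is reward` is the behavioral counterpart of the refactor: the transform output is no longer guaranteed to alias the original reward tensor, so the result must be read back from the tensordict. The remaining assertions still pin down the semantics (strictly positive rewards become 1, the rest become 0). Below is a torch-only sketch of that expectation, written out-of-place; it mirrors what the test checks rather than the library's exact implementation.

import torch

reward = torch.tensor([[-1.0], [0.0], [2.5]])
binarized = (reward > 0).to(reward.dtype)  # new tensor; `reward` is untouched

assert binarized is not reward
assert torch.count_nonzero(binarized) == torch.sum(reward > 0)
print(binarized)  # tensor([[0.], [0.], [1.]])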

torchrl/envs/transforms/transforms.py

Lines changed: 12 additions & 65 deletions
@@ -113,13 +113,6 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Resets a tranform if it is stateful."""
         return tensordict
 
-    def _check_inplace(self) -> None:
-        if not hasattr(self, "inplace"):
-            raise AttributeError(
-                f"Transform of class {self.__class__.__name__} has no "
-                f"attribute inplace, consider implementing it."
-            )
-
     def init(self, tensordict) -> None:
         pass
 
@@ -134,11 +127,13 @@ def _apply_transform(self, obs: torch.Tensor) -> None:
 
     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Reads the input tensordict, and for the selected keys, applies the transform."""
-        self._check_inplace()
         for in_key, out_key in zip(self.in_keys, self.out_keys):
             if in_key in tensordict.keys(include_nested=True):
                 observation = self._apply_transform(tensordict.get(in_key))
-                tensordict.set(out_key, observation, inplace=self.inplace)
+                tensordict.set(
+                    out_key,
+                    observation,
+                )
         return tensordict
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -160,11 +155,13 @@ def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
         return obs
 
     def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
-        self._check_inplace()
        for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
            if in_key in tensordict.keys(include_nested=True):
                observation = self._inv_apply_transform(tensordict.get(in_key))
-                tensordict.set(out_key, observation, inplace=self.inplace)
+                tensordict.set(
+                    out_key,
+                    observation,
+                )
         return tensordict
 
     def inv(self, tensordict: TensorDictBase) -> TensorDictBase:
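
The practical effect of dropping `inplace=self.inplace` in `_call` and `_inv_call` is that the transformed value now replaces the stored entry rather than (potentially) being copied into the pre-existing tensor. The snippet below is a library-free sketch of the two write styles, using a plain dict and torch tensors as a hypothetical stand-in for a tensordict entry; it illustrates the distinction only and does not reproduce tensordict's actual `set` semantics.

import torch

storage = {"observation": torch.zeros(3)}
old = storage["observation"]

# In-place style (roughly what `set(..., inplace=True)` implied):
# data is copied into the existing tensor, identity is preserved.
old.copy_(torch.ones(3))
assert storage["observation"] is old

# Out-of-place style (the plain `set(...)` used after this commit):
# the key is rebound to the freshly computed tensor.
new_obs = torch.full((3,), 2.0)
storage["observation"] = new_obs
assert storage["observation"] is not old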
@@ -607,8 +604,6 @@ def __del__(self):
 class ObservationTransform(Transform):
     """Abstract class for transformations of the observations."""
 
-    inplace = False
-
     def __init__(
         self,
         in_keys: Optional[Sequence[str]] = None,
@@ -634,8 +629,6 @@ class Compose(Transform):
 
     """
 
-    inplace = False
-
     def __init__(self, *transforms: Transform):
         super().__init__(in_keys=[])
         self.transforms = nn.ModuleList(transforms)
@@ -773,8 +766,6 @@ class ToTensorImage(ObservationTransform):
         torch.Size([1, 1, 3, 10, 11]) torch.float32
     """
 
-    inplace = False
-
     def __init__(
         self,
         unsqueeze: bool = False,
@@ -827,8 +818,6 @@ class RewardClipping(Transform):
 
     """
 
-    inplace = True
-
     def __init__(
         self,
         clamp_min: float = None,
@@ -850,11 +839,11 @@ def __init__(
 
     def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
         if self.clamp_max is not None and self.clamp_min is not None:
-            reward = reward.clamp_(self.clamp_min, self.clamp_max)
+            reward = reward.clamp(self.clamp_min, self.clamp_max)
         elif self.clamp_min is not None:
-            reward = reward.clamp_min_(self.clamp_min)
+            reward = reward.clamp_min(self.clamp_min)
         elif self.clamp_max is not None:
-            reward = reward.clamp_max_(self.clamp_max)
+            reward = reward.clamp_max(self.clamp_max)
         return reward
 
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
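
`RewardClipping._apply_transform` now uses the out-of-place `clamp` / `clamp_min` / `clamp_max` variants, so the incoming reward tensor itself is left untouched and the clipped result is handed back to the base class. A small torch-only illustration of the difference (not taken from the repository):

import torch

reward = torch.tensor([-2.0, 0.5, 3.0])

clipped = reward.clamp(-1.0, 1.0)  # out-of-place: returns a new tensor
print(reward)   # tensor([-2.0000,  0.5000,  3.0000]) -- unchanged
print(clipped)  # tensor([-1.0000,  0.5000,  1.0000])

# The previous in-place spelling would have mutated `reward` directly:
# reward.clamp_(-1.0, 1.0)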
@@ -884,8 +873,6 @@ def __repr__(self) -> str:
 class BinarizeReward(Transform):
     """Maps the reward to a binary value (0 or 1) if the reward is null or non-null, respectively."""
 
-    inplace = True
-
     def __init__(
         self,
         in_keys: Optional[Sequence[str]] = None,
@@ -917,8 +904,6 @@ class Resize(ObservationTransform):
         interpolation (str): interpolation method
     """
 
-    inplace = False
-
     def __init__(
         self,
         w: int,
@@ -986,8 +971,6 @@ class CenterCrop(ObservationTransform):
         h (int, optional): resulting height. If None, then w is used (square crop).
     """
 
-    inplace = False
-
     def __init__(
         self,
         w: int,
@@ -1043,8 +1026,6 @@ class FlattenObservation(ObservationTransform):
         last_dim (int): last dimension of the dimensions to flatten.
     """
 
-    inplace = False
-
     def __init__(
         self,
         first_dim: int,
@@ -1115,7 +1096,6 @@ class UnsqueezeTransform(Transform):
     """
 
     invertible = True
-    inplace = False
 
     @classmethod
     def __new__(cls, *args, **kwargs):
@@ -1232,7 +1212,6 @@ class SqueezeTransform(UnsqueezeTransform):
     """
 
     invertible = True
-    inplace = False
 
     def __init__(
         self,
@@ -1269,8 +1248,6 @@ def inv(self, tensordict: TensorDictBase) -> TensorDictBase:
 class GrayScale(ObservationTransform):
     """Turns a pixel observation to grayscale."""
 
-    inplace = False
-
     def __init__(self, in_keys: Optional[Sequence[str]] = None):
         if in_keys is None:
             in_keys = IMAGE_KEYS
@@ -1342,8 +1319,6 @@ class ObservationNorm(ObservationTransform):
 
     """
 
-    inplace = True
-
     def __init__(
         self,
         loc: Optional[float, torch.Tensor] = None,
@@ -1471,7 +1446,6 @@ def raise_initialization_exception(module):
             raise RuntimeError("Non-finite values found in loc")
         if not torch.isfinite(scale).all():
             raise RuntimeError("Non-finite values found in scale")
-
         self.register_buffer("loc", loc)
         self.register_buffer("scale", scale.clamp_min(self.eps))
 
@@ -1662,8 +1636,6 @@ class RewardScaling(Transform):
            as it is done for standardization. Default is `False`.
     """
 
-    inplace = True
-
     def __init__(
         self,
         loc: Union[float, torch.Tensor],
@@ -1717,8 +1689,6 @@ def __repr__(self) -> str:
 class FiniteTensorDictCheck(Transform):
     """This transform will check that all the items of the tensordict are finite, and raise an exception if they are not."""
 
-    inplace = False
-
     def __init__(self):
         super().__init__(in_keys=[])
 
@@ -1741,7 +1711,6 @@ class DoubleToFloat(Transform):
     """
 
     invertible = True
-    inplace = False
 
     def __init__(
         self,
@@ -1835,7 +1804,6 @@ class CatTensors(Transform):
     """
 
     invertible = False
-    inplace = False
 
     def __init__(
         self,
@@ -1992,8 +1960,6 @@ class DiscreteActionProjection(Transform):
         tensor([1])
     """
 
-    inplace = False
-
     def __init__(self, max_n: int, m: int, action_key: str = "action"):
         super().__init__([action_key])
         self.max_n = max_n
@@ -2035,8 +2001,6 @@ class FrameSkipTransform(Transform):
 
     """
 
-    inplace = False
-
     def __init__(self, frame_skip: int = 1):
         super().__init__([])
         if frame_skip < 1:
@@ -2069,8 +2033,6 @@ class NoopResetEnv(Transform):
 
     """
 
-    inplace = True
-
     def __init__(self, noops: int = 30, random: bool = True):
         """Sample initial states by taking random number of no-ops on reset.
 
@@ -2169,8 +2131,6 @@ class TensorDictPrimer(Transform):
         is_shared=False)
     """
 
-    inplace = False
-
     def __init__(self, random=False, default_value=0.0, **kwargs):
         self.primers = kwargs
         self.random = random
@@ -2271,8 +2231,6 @@ class gSDENoise(Transform):
     See the :func:`~torchrl.modules.models.exploration.gSDEModule' for more info.
     """
 
-    inplace = False
-
     def __init__(
         self,
         state_dim=None,
@@ -2351,8 +2309,6 @@ class VecNorm(Transform):
 
     """
 
-    inplace = True
-
     def __init__(
         self,
         in_keys: Optional[Sequence[str]] = None,
@@ -2402,7 +2358,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
                 key, tensordict.get(key), N=max(1, tensordict.numel())
             )
 
-            tensordict.set_(key, new_val)
+            tensordict.set(key, new_val)
 
         if self.lock is not None:
             self.lock.release()
@@ -2582,8 +2538,6 @@ class RewardSum(Transform):
        this transform hos no effect.
     """
 
-    inplace = True
-
     def __init__(
         self,
         in_keys: Optional[Sequence[str]] = None,
@@ -2654,7 +2608,6 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Updates the episode rewards with the step rewards."""
         # Sanity checks
-        self._check_inplace()
         for in_key in self.in_keys:
             if in_key not in tensordict.keys():
                 return tensordict
@@ -2727,7 +2680,6 @@ class StepCounter(Transform):
     """
 
     invertible = False
-    inplace = True
 
     def __init__(self, max_steps: Optional[int] = None):
         if max_steps is not None and max_steps < 1:
@@ -2799,8 +2751,6 @@ class ExcludeTransform(Transform):
 
     """
 
-    inplace = False
-
     def __init__(self, *excluded_keys):
         super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
         if not all(isinstance(item, str) for item in excluded_keys):
@@ -2840,8 +2790,6 @@ class SelectTransform(Transform):
 
     """
 
-    inplace = False
-
     def __init__(self, *selected_keys):
         super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
         if not all(isinstance(item, str) for item in selected_keys):
@@ -2887,7 +2835,6 @@ class TimeMaxPool(Transform):
         T (int, optional): Number of time steps over which to apply max pooling.
     """
 
-    inplace = False
     invertible = False
 
     def __init__(
