Skip to content

Commit dc0eebb

Browse files
jasonfkutJason Kutarniavmoens
authored
[BugFix] Prevent transform parent from being reassigned (#641)
* Prevent transform parent from being reassigned * Added missed resets, fixed flake8 error * Prevent transform.parent from modifying original transforms * Added test to ensure Transform.parent doesn't alter original _parents Co-authored-by: Jason Kutarnia <jasonkut@fb.com> Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent cd488b4 commit dc0eebb

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

test/test_transforms.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,27 @@ def test_nested_transformed_env():
327327
assert children[1] == t2
328328

329329

330+
def test_transform_parent():
331+
base_env = ContinuousActionVecMockEnv()
332+
t1 = RewardScaling(0, 1)
333+
t2 = RewardScaling(0, 2)
334+
env = TransformedEnv(TransformedEnv(base_env, t1), t2)
335+
t3 = RewardClipping(0.1, 0.5)
336+
env.append_transform(t3)
337+
338+
t1_parent_gt = t1._parent
339+
t2_parent_gt = t2._parent
340+
t3_parent_gt = t3._parent
341+
342+
_ = t1.parent
343+
_ = t2.parent
344+
_ = t3.parent
345+
346+
assert t1_parent_gt == t1._parent
347+
assert t2_parent_gt == t2._parent
348+
assert t3_parent_gt == t3._parent
349+
350+
330351
class TestTransforms:
331352
@pytest.mark.skipif(not _has_tv, reason="no torchvision")
332353
@pytest.mark.parametrize("interpolation", ["bilinear", "bicubic"])

torchrl/envs/transforms/transforms.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,13 @@ def __repr__(self) -> str:
204204
return f"{self.__class__.__name__}(keys={self.keys_in})"
205205

206206
def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
207+
if self.__dict__["_parent"] is not None:
208+
raise AttributeError("parent of transform already set")
207209
self.__dict__["_parent"] = parent
208210

211+
def reset_parent(self) -> None:
212+
self.__dict__["_parent"] = None
213+
209214
@property
210215
def parent(self) -> EnvBase:
211216
if not hasattr(self, "_parent"):
@@ -226,15 +231,20 @@ def parent(self) -> EnvBase:
226231
raise ValueError(
227232
f"Compose parent was of type {type(compose_parent)} but expected TransformedEnv."
228233
)
234+
if compose_parent.transform is not compose:
235+
comp_parent_trans = copy(compose_parent.transform)
236+
comp_parent_trans.reset_parent()
237+
else:
238+
comp_parent_trans = None
229239
out = TransformedEnv(
230240
compose_parent.base_env,
231-
transform=compose_parent.transform
232-
if compose_parent.transform is not compose
233-
else None,
241+
transform=comp_parent_trans,
234242
)
235-
for transform in compose.transforms:
236-
if transform is self:
243+
for orig_trans in compose.transforms:
244+
if orig_trans is self:
237245
break
246+
transform = copy(orig_trans)
247+
transform.reset_parent()
238248
out.append_transform(transform)
239249
elif isinstance(parent, TransformedEnv):
240250
out = TransformedEnv(parent.base_env)
@@ -287,9 +297,16 @@ def __init__(
287297
# we don't use isinstance as some transforms may be subclassed from
288298
# Compose but with other features that we don't want to loose.
289299
transform = [transform]
300+
else:
301+
for t in transform:
302+
t.reset_parent()
290303
env_transform = env.transform
291304
if type(env_transform) is not Compose:
305+
env_transform.reset_parent()
292306
env_transform = [env_transform]
307+
else:
308+
for t in env_transform:
309+
t.reset_parent()
293310
transform = Compose(*env_transform, *transform).to(device)
294311
else:
295312
self._set_env(env, device)
@@ -474,9 +491,10 @@ def append_transform(self, transform: Transform) -> None:
474491
transform = transform.to(self.device)
475492
if not isinstance(self.transform, Compose):
476493
prev_transform = self.transform
494+
prev_transform.reset_parent()
477495
self.transform = Compose()
478496
self.transform.append(prev_transform)
479-
self.transform.set_parent(self)
497+
480498
self.transform.append(transform)
481499

482500
def insert_transform(self, index: int, transform: Transform) -> None:
@@ -538,8 +556,6 @@ def to(self, device: DEVICE_TYPING) -> TransformedEnv:
538556
def __setattr__(self, key, value):
539557
propobj = getattr(self.__class__, key, None)
540558

541-
if isinstance(value, Transform):
542-
value.set_parent(self)
543559
if isinstance(propobj, property):
544560
ancestors = list(__class__.__mro__)[::-1]
545561
while isinstance(propobj, property):

0 commit comments

Comments
 (0)