@@ -204,8 +204,13 @@ def __repr__(self) -> str:
204
204
return f"{ self .__class__ .__name__ } (keys={ self .keys_in } )"
205
205
206
206
def set_parent (self , parent : Union [Transform , EnvBase ]) -> None :
207
+ if self .__dict__ ["_parent" ] is not None :
208
+ raise AttributeError ("parent of transform already set" )
207
209
self .__dict__ ["_parent" ] = parent
208
210
211
+ def reset_parent (self ) -> None :
212
+ self .__dict__ ["_parent" ] = None
213
+
209
214
@property
210
215
def parent (self ) -> EnvBase :
211
216
if not hasattr (self , "_parent" ):
@@ -226,15 +231,20 @@ def parent(self) -> EnvBase:
226
231
raise ValueError (
227
232
f"Compose parent was of type { type (compose_parent )} but expected TransformedEnv."
228
233
)
234
+ if compose_parent .transform is not compose :
235
+ comp_parent_trans = copy (compose_parent .transform )
236
+ comp_parent_trans .reset_parent ()
237
+ else :
238
+ comp_parent_trans = None
229
239
out = TransformedEnv (
230
240
compose_parent .base_env ,
231
- transform = compose_parent .transform
232
- if compose_parent .transform is not compose
233
- else None ,
241
+ transform = comp_parent_trans ,
234
242
)
235
- for transform in compose .transforms :
236
- if transform is self :
243
+ for orig_trans in compose .transforms :
244
+ if orig_trans is self :
237
245
break
246
+ transform = copy (orig_trans )
247
+ transform .reset_parent ()
238
248
out .append_transform (transform )
239
249
elif isinstance (parent , TransformedEnv ):
240
250
out = TransformedEnv (parent .base_env )
@@ -287,9 +297,16 @@ def __init__(
287
297
# we don't use isinstance as some transforms may be subclassed from
288
298
# Compose but with other features that we don't want to loose.
289
299
transform = [transform ]
300
+ else :
301
+ for t in transform :
302
+ t .reset_parent ()
290
303
env_transform = env .transform
291
304
if type (env_transform ) is not Compose :
305
+ env_transform .reset_parent ()
292
306
env_transform = [env_transform ]
307
+ else :
308
+ for t in env_transform :
309
+ t .reset_parent ()
293
310
transform = Compose (* env_transform , * transform ).to (device )
294
311
else :
295
312
self ._set_env (env , device )
@@ -474,9 +491,10 @@ def append_transform(self, transform: Transform) -> None:
474
491
transform = transform .to (self .device )
475
492
if not isinstance (self .transform , Compose ):
476
493
prev_transform = self .transform
494
+ prev_transform .reset_parent ()
477
495
self .transform = Compose ()
478
496
self .transform .append (prev_transform )
479
- self . transform . set_parent ( self )
497
+
480
498
self .transform .append (transform )
481
499
482
500
def insert_transform (self , index : int , transform : Transform ) -> None :
@@ -538,8 +556,6 @@ def to(self, device: DEVICE_TYPING) -> TransformedEnv:
538
556
def __setattr__ (self , key , value ):
539
557
propobj = getattr (self .__class__ , key , None )
540
558
541
- if isinstance (value , Transform ):
542
- value .set_parent (self )
543
559
if isinstance (propobj , property ):
544
560
ancestors = list (__class__ .__mro__ )[::- 1 ]
545
561
while isinstance (propobj , property ):
0 commit comments