Skip to content

Commit b876ce6

Browse files
authored
[BugFix] Compose cloning fix (#899)
1 parent 4317aa7 commit b876ce6

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

test/test_transforms.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6059,6 +6059,42 @@ def test_smoke_compose_transform(transform):
60596059
Compose(transform())
60606060

60616061

6062+
@pytest.mark.parametrize("transform", transforms)
6063+
def test_clone_parent(transform):
6064+
base_env1 = ContinuousActionVecMockEnv()
6065+
base_env2 = ContinuousActionVecMockEnv()
6066+
env = TransformedEnv(base_env1, transform())
6067+
env_clone = TransformedEnv(base_env2, env.transform.clone())
6068+
6069+
assert env_clone.transform.parent.base_env is not base_env1
6070+
assert env_clone.transform.parent.base_env is base_env2
6071+
assert env.transform.parent.base_env is not base_env2
6072+
assert env.transform.parent.base_env is base_env1
6073+
6074+
6075+
@pytest.mark.parametrize("transform", transforms)
6076+
def test_clone_parent_compose(transform):
6077+
base_env1 = ContinuousActionVecMockEnv()
6078+
base_env2 = ContinuousActionVecMockEnv()
6079+
env = TransformedEnv(base_env1, Compose(ToTensorImage(), transform()))
6080+
t = env.transform.clone()
6081+
6082+
assert t.parent is None
6083+
assert t[0].parent is None
6084+
assert t[1].parent is None
6085+
6086+
env_clone = TransformedEnv(base_env2, Compose(ToTensorImage(), *t))
6087+
6088+
assert env_clone.transform[0].parent.base_env is not base_env1
6089+
assert env_clone.transform[0].parent.base_env is base_env2
6090+
assert env.transform[0].parent.base_env is not base_env2
6091+
assert env.transform[0].parent.base_env is base_env1
6092+
assert env_clone.transform[1].parent.base_env is not base_env1
6093+
assert env_clone.transform[1].parent.base_env is base_env2
6094+
assert env.transform[1].parent.base_env is not base_env2
6095+
assert env.transform[1].parent.base_env is base_env1
6096+
6097+
60626098
if __name__ == "__main__":
60636099
args, unknown = argparse.ArgumentParser().parse_known_args()
60646100
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/transforms/transforms.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,10 @@ def reset_parent(self) -> None:
265265

266266
def clone(self):
267267
self_copy = copy(self)
268-
self_copy.reset_parent()
268+
state = copy(self.__dict__)
269+
state["_container"] = None
270+
state["_parent"] = None
271+
self_copy.__dict__.update(state)
269272
return self_copy
270273

271274
@property
@@ -778,7 +781,11 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Compose:
778781
return super().to(dest)
779782

780783
def __iter__(self):
781-
return iter(self.transforms)
784+
# We clone the transforms to be able to do
785+
# env2 = TransformedEnv(base_env, *env1.transform.clone())
786+
# which will otherwise result in an error because all the transforms
787+
# from the Compose already have a container.
788+
yield from (t.clone() for t in self.transforms)
782789

783790
def __len__(self):
784791
return len(self.transforms)
@@ -794,6 +801,17 @@ def empty_cache(self):
794801
t.empty_cache()
795802
super().empty_cache()
796803

804+
def reset_parent(self):
805+
for t in self.transforms:
806+
t.reset_parent()
807+
super().reset_parent()
808+
809+
def clone(self):
810+
transforms = []
811+
for t in self.transforms:
812+
transforms.append(t.clone())
813+
return Compose(*transforms)
814+
797815

798816
class ToTensorImage(ObservationTransform):
799817
"""Transforms a numpy-like image (3 x W x H) to a pytorch image (3 x W x H).

0 commit comments

Comments
 (0)