Skip to content

Commit e0e3cf3

Browse files
authored
[BugFix] deepcopy specs before transforming (#461)
1 parent ef1bf20 commit e0e3cf3

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

torchrl/envs/transforms/r3m.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
8787
keys = [key for key in observation_spec._specs.keys() if key in self.keys_in]
8888
device = observation_spec[keys[0]].device
8989

90+
observation_spec = CompositeSpec(**observation_spec)
9091
if self.del_keys:
9192
for key_in in keys:
9293
del observation_spec[key_in]
@@ -272,7 +273,7 @@ def _init(self):
272273
model_name=model_name,
273274
del_keys=True,
274275
)
275-
transforms = [*transforms, normalize, network]
276+
transforms = [*transforms, network]
276277

277278
for transform in transforms:
278279
self.append(transform)

torchrl/envs/transforms/transforms.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
def _apply_to_composite(function):
7373
def new_fun(self, observation_spec):
7474
if isinstance(observation_spec, CompositeSpec):
75-
d = copy(observation_spec._specs)
75+
d = observation_spec._specs
7676
for key_in, key_out in zip(self.keys_in, self.keys_out):
7777
if key_in in observation_spec.keys():
7878
d[key_out] = function(self, observation_spec[key_in])
@@ -506,7 +506,9 @@ def __getattr__(self, attr: str) -> Any:
506506
)
507507

508508
def __repr__(self) -> str:
509-
return f"TransformedEnv(env={self.base_env}, transform={self.transform})"
509+
env_str = indent(f"env={self.base_env}", 4 * " ")
510+
t_str = indent(f"transform={self.transform}", 4 * " ")
511+
return f"TransformedEnv(\n{env_str},\n{t_str})"
510512

511513
def _erase_metadata(self):
512514
if self.cache_specs:
@@ -621,7 +623,9 @@ def __getitem__(self, item: Union[int, slice, List]) -> Union:
621623
transform = self.transforms
622624
transform = transform[item]
623625
if not isinstance(transform, Transform):
624-
return Compose(*self.transforms[item])
626+
out = Compose(*self.transforms[item])
627+
out.set_parent(self.parent)
628+
return out
625629
return transform
626630

627631
def dump(self, **kwargs) -> None:
@@ -737,7 +741,7 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor:
737741

738742
@_apply_to_composite
739743
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
740-
self._pixel_observation(observation_spec)
744+
observation_spec = self._pixel_observation(deepcopy(observation_spec))
741745
observation_spec.shape = torch.Size(
742746
[
743747
*observation_spec.shape[:-3],
@@ -747,13 +751,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
747751
]
748752
)
749753
observation_spec.dtype = self.dtype
750-
observation_spec = observation_spec
751754
return observation_spec
752755

753756
def _pixel_observation(self, spec: TensorSpec) -> None:
754-
if isinstance(spec, BoundedTensorSpec):
757+
if isinstance(spec.space, ContinuousBox):
755758
spec.space.maximum = self._apply_transform(spec.space.maximum)
756759
spec.space.minimum = self._apply_transform(spec.space.minimum)
760+
return spec
757761

758762

759763
class RewardClipping(Transform):
@@ -899,6 +903,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
899903

900904
@_apply_to_composite
901905
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
906+
observation_spec = deepcopy(observation_spec)
902907
space = observation_spec.space
903908
if isinstance(space, ContinuousBox):
904909
space.minimum = self._apply_transform(space.minimum)
@@ -962,7 +967,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
962967
}
963968
)
964969
else:
965-
_observation_spec = observation_spec
970+
_observation_spec = deepcopy(observation_spec)
971+
966972
space = _observation_spec.space
967973
if isinstance(space, ContinuousBox):
968974
space.minimum = self._apply_transform(space.minimum)
@@ -1019,6 +1025,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
10191025

10201026
@_apply_to_composite
10211027
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1028+
observation_spec = deepcopy(observation_spec)
10221029
space = observation_spec.space
10231030
if isinstance(space, ContinuousBox):
10241031
space.minimum = self._apply_transform(space.minimum)
@@ -1122,25 +1129,26 @@ def _transform_spec(self, spec: TensorSpec) -> None:
11221129
spec.shape = space.minimum.shape
11231130
else:
11241131
spec.shape = self._apply_transform(torch.zeros(spec.shape)).shape
1132+
return spec
11251133

11261134
def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
11271135
if "action" in self.keys_inv_in:
1128-
self._transform_spec(action_spec)
1136+
action_spec = self._transform_spec(deepcopy(action_spec))
11291137
return action_spec
11301138

11311139
def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
11321140
for key in self.keys_inv_in:
1133-
self._transform_spec(input_spec[key])
1141+
input_spec = self._transform_spec(deepcopy(input_spec[key]))
11341142
return input_spec
11351143

11361144
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
11371145
if "reward" in self.keys_in:
1138-
self._transform_spec(reward_spec)
1146+
reward_spec = self._transform_spec(deepcopy(reward_spec))
11391147
return reward_spec
11401148

11411149
@_apply_to_composite
11421150
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1143-
self._transform_spec(observation_spec)
1151+
observation_spec = self._transform_spec(deepcopy(observation_spec))
11441152
return observation_spec
11451153

11461154
def __repr__(self) -> str:
@@ -1207,6 +1215,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
12071215

12081216
@_apply_to_composite
12091217
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1218+
observation_spec = deepcopy(observation_spec)
12101219
space = observation_spec.space
12111220
if isinstance(space, ContinuousBox):
12121221
space.minimum = self._apply_transform(space.minimum)
@@ -1295,6 +1304,7 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
12951304

12961305
@_apply_to_composite
12971306
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1307+
observation_spec = deepcopy(observation_spec)
12981308
space = observation_spec.space
12991309
if isinstance(space, ContinuousBox):
13001310
space.minimum = self._apply_transform(space.minimum)

0 commit comments

Comments (0)