Skip to content

Commit d9b6ed9

Browse files
romainjlnvmoens
andauthored
[Feature] Too many deepcopy in transforms.py (#625)
* Remove optional deepcopy in tranform.py * lint * Refactoring/Renaming deepcopy cleanup related tests * lint Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent dc0eebb commit d9b6ed9

File tree

2 files changed

+56
-18
lines changed

2 files changed

+56
-18
lines changed

test/test_transforms.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import argparse
66
from copy import copy, deepcopy
77

8+
import numpy as np
89
import pytest
910
import torch
1011
from _utils_internal import get_available_devices, retry
@@ -22,6 +23,7 @@
2223
NdUnboundedContinuousTensorSpec,
2324
TensorDict,
2425
UnboundedContinuousTensorSpec,
26+
BoundedTensorSpec,
2527
)
2628
from torchrl.envs import (
2729
BinarizeReward,
@@ -313,6 +315,50 @@ def test_added_transforms_are_in_eval_mode():
313315
assert t.transform[1].training
314316

315317

318+
class TestTransformedEnv:
319+
def test_independent_obs_specs_from_shared_env(self):
320+
obs_spec = CompositeSpec(
321+
next_observation=BoundedTensorSpec(minimum=0, maximum=10)
322+
)
323+
base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec)
324+
t1 = TransformedEnv(base_env, transform=ObservationNorm(loc=3, scale=2))
325+
t2 = TransformedEnv(base_env, transform=ObservationNorm(loc=1, scale=6))
326+
327+
t1_obs_spec = t1.observation_spec
328+
t2_obs_spec = t2.observation_spec
329+
330+
assert t1_obs_spec["next_observation"].space.minimum == 3
331+
assert t1_obs_spec["next_observation"].space.maximum == 23
332+
333+
assert t2_obs_spec["next_observation"].space.minimum == 1
334+
assert t2_obs_spec["next_observation"].space.maximum == 61
335+
336+
assert base_env.observation_spec["next_observation"].space.minimum == 0
337+
assert base_env.observation_spec["next_observation"].space.maximum == 10
338+
339+
def test_independent_reward_specs_from_shared_env(self):
340+
reward_spec = UnboundedContinuousTensorSpec()
341+
base_env = ContinuousActionVecMockEnv(reward_spec=reward_spec)
342+
t1 = TransformedEnv(
343+
base_env, transform=RewardClipping(clamp_min=0, clamp_max=4)
344+
)
345+
t2 = TransformedEnv(
346+
base_env, transform=RewardClipping(clamp_min=-2, clamp_max=2)
347+
)
348+
349+
t1_reward_spec = t1.reward_spec
350+
t2_reward_spec = t2.reward_spec
351+
352+
assert t1_reward_spec.space.minimum == 0
353+
assert t1_reward_spec.space.maximum == 4
354+
355+
assert t2_reward_spec.space.minimum == -2
356+
assert t2_reward_spec.space.maximum == 2
357+
358+
assert base_env.reward_spec.space.minimum == -np.inf
359+
assert base_env.reward_spec.space.maximum == np.inf
360+
361+
316362
def test_nested_transformed_env():
317363
base_env = ContinuousActionVecMockEnv()
318364
t1 = RewardScaling(0, 1)

torchrl/envs/transforms/transforms.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor:
757757

758758
@_apply_to_composite
759759
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
760-
observation_spec = self._pixel_observation(deepcopy(observation_spec))
760+
observation_spec = self._pixel_observation(observation_spec)
761761
observation_spec.shape = torch.Size(
762762
[
763763
*observation_spec.shape[:-3],
@@ -913,7 +913,6 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
913913

914914
@_apply_to_composite
915915
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
916-
observation_spec = deepcopy(observation_spec)
917916
space = observation_spec.space
918917
if isinstance(space, ContinuousBox):
919918
space.minimum = self._apply_transform(space.minimum)
@@ -970,20 +969,17 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
970969
for key, _obs_spec in observation_spec._specs.items()
971970
}
972971
)
973-
else:
974-
_observation_spec = deepcopy(observation_spec)
975972

976-
space = _observation_spec.space
973+
space = observation_spec.space
977974
if isinstance(space, ContinuousBox):
978975
space.minimum = self._apply_transform(space.minimum)
979976
space.maximum = self._apply_transform(space.maximum)
980-
_observation_spec.shape = space.minimum.shape
977+
observation_spec.shape = space.minimum.shape
981978
else:
982-
_observation_spec.shape = self._apply_transform(
983-
torch.zeros(_observation_spec.shape)
979+
observation_spec.shape = self._apply_transform(
980+
torch.zeros(observation_spec.shape)
984981
).shape
985982

986-
observation_spec = _observation_spec
987983
return observation_spec
988984

989985
def __repr__(self) -> str:
@@ -1036,7 +1032,6 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
10361032

10371033
@_apply_to_composite
10381034
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1039-
observation_spec = deepcopy(observation_spec)
10401035
space = observation_spec.space
10411036

10421037
if isinstance(space, ContinuousBox):
@@ -1139,17 +1134,17 @@ def _transform_spec(self, spec: TensorSpec) -> None:
11391134

11401135
def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
11411136
for key in self.keys_inv_in:
1142-
input_spec = self._transform_spec(deepcopy(input_spec[key]))
1137+
input_spec = self._transform_spec(input_spec[key])
11431138
return input_spec
11441139

11451140
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
11461141
if "reward" in self.keys_in:
1147-
reward_spec = self._transform_spec(deepcopy(reward_spec))
1142+
reward_spec = self._transform_spec(reward_spec)
11481143
return reward_spec
11491144

11501145
@_apply_to_composite
11511146
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1152-
observation_spec = self._transform_spec(deepcopy(observation_spec))
1147+
observation_spec = self._transform_spec(observation_spec)
11531148
return observation_spec
11541149

11551150
def __repr__(self) -> str:
@@ -1213,7 +1208,6 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
12131208

12141209
@_apply_to_composite
12151210
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1216-
observation_spec = deepcopy(observation_spec)
12171211
space = observation_spec.space
12181212
if isinstance(space, ContinuousBox):
12191213
space.minimum = self._apply_transform(space.minimum)
@@ -1303,7 +1297,6 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
13031297

13041298
@_apply_to_composite
13051299
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
1306-
observation_spec = deepcopy(observation_spec)
13071300
space = observation_spec.space
13081301
if isinstance(space, ContinuousBox):
13091302
space.minimum = self._apply_transform(space.minimum)
@@ -1756,9 +1749,8 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
17561749
return action
17571750

17581751
def tranform_input_spec(self, input_spec: CompositeSpec):
1759-
input_spec_out = deepcopy(input_spec)
1760-
input_spec_out["action"] = self.transform_action_spec(input_spec_out["action"])
1761-
return input_spec_out
1752+
input_spec["action"] = self.transform_action_spec(input_spec["action"])
1753+
return input_spec
17621754

17631755
def __repr__(self) -> str:
17641756
return (

0 commit comments

Comments
 (0)