Skip to content

Commit 9003a56

Browse files
authored
[Feature] Exclude and select transforms (#832)
1 parent 4ebc764 commit 9003a56

File tree

5 files changed

+172
-20
lines changed

5 files changed

+172
-20
lines changed

docs/source/reference/envs.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ in the environment. The keys to be included in this inverse transform are passed
208208
CatTensors
209209
CenterCrop
210210
Compose
211+
DiscreteActionProjection
211212
DoubleToFloat
213+
ExcludeTransform
212214
FiniteTensorDictCheck
213215
FlattenObservation
214216
FrameSkipTransform
@@ -218,19 +220,20 @@ in the environment. The keys to be included in this inverse transform are passed
218220
ObservationNorm
219221
ObservationTransform
220222
PinMemoryTransform
223+
R3MTransform
221224
Resize
222225
RewardClipping
223226
RewardScaling
224227
RewardSum
228+
SelectTransform
225229
SqueezeTransform
226230
StepCounter
227231
TensorDictPrimer
228232
ToTensorImage
229233
UnsqueezeTransform
230234
VecNorm
231-
R3MTransform
232-
VIPTransform
233235
VIPRewardTransform
236+
VIPTransform
234237

235238
Recorders
236239
---------

test/test_transforms.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,49 @@
2222
MockBatchedLockedEnv,
2323
MockBatchedUnLockedEnv,
2424
)
25-
from tensordict import TensorDict
25+
from tensordict.tensordict import TensorDict, TensorDictBase
2626
from torch import multiprocessing as mp, Tensor
2727
from torchrl._utils import prod
2828
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
2929
from torchrl.envs import (
3030
BinarizeReward,
3131
CatFrames,
3232
CatTensors,
33+
CenterCrop,
3334
Compose,
35+
DiscreteActionProjection,
3436
DoubleToFloat,
37+
EnvBase,
3538
EnvCreator,
39+
ExcludeTransform,
3640
FiniteTensorDictCheck,
3741
FlattenObservation,
42+
FrameSkipTransform,
3843
GrayScale,
44+
gSDENoise,
45+
NoopResetEnv,
3946
ObservationNorm,
4047
ParallelEnv,
48+
PinMemoryTransform,
4149
R3MTransform,
4250
Resize,
4351
RewardClipping,
4452
RewardScaling,
4553
RewardSum,
54+
SelectTransform,
4655
SerialEnv,
56+
SqueezeTransform,
4757
StepCounter,
58+
TensorDictPrimer,
4859
ToTensorImage,
60+
TransformedEnv,
61+
UnsqueezeTransform,
4962
VIPTransform,
5063
)
5164
from torchrl.envs.libs.gym import _has_gym, GymEnv
52-
from torchrl.envs.transforms import TransformedEnv, VecNorm
65+
from torchrl.envs.transforms import VecNorm
5366
from torchrl.envs.transforms.r3m import _R3MNet
54-
from torchrl.envs.transforms.transforms import (
55-
_has_tv,
56-
CenterCrop,
57-
DiscreteActionProjection,
58-
FrameSkipTransform,
59-
gSDENoise,
60-
NoopResetEnv,
61-
PinMemoryTransform,
62-
SqueezeTransform,
63-
TensorDictPrimer,
64-
UnsqueezeTransform,
65-
)
67+
from torchrl.envs.transforms.transforms import _has_tv
6668
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
6769
from torchrl.envs.utils import check_env_specs
6870

@@ -2268,6 +2270,54 @@ def test_batch_unlocked_with_batch_size_transformed(device):
22682270
env.step(td_expanded)
22692271

22702272

2273+
class TestExcludeSelect:
    """Tests for ``ExcludeTransform`` and ``SelectTransform`` on a toy env."""

    class EnvWithManyKeys(EnvBase):
        """Mock environment whose observation spec holds entries "a", "b", "c"."""

        def __init__(self):
            super().__init__()
            obs_entries = {
                name: UnboundedContinuousTensorSpec(3) for name in ("a", "b", "c")
            }
            self.observation_spec = CompositeSpec(**obs_entries)
            self.reward_spec = UnboundedContinuousTensorSpec(1)
            self.input_spec = CompositeSpec(action=UnboundedContinuousTensorSpec(2))

        def _step(
            self,
            tensordict: TensorDictBase,
        ) -> TensorDictBase:
            # Random observation, completed with a random reward and a
            # "done" flag that is always False.
            out = self.observation_spec.rand()
            extras = {
                "reward": self.reward_spec.rand(),
                "done": torch.zeros(1, dtype=torch.bool),
            }
            return out.update(extras)

        def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
            # Same as ``_step`` but without a reward entry.
            out = self.observation_spec.rand()
            extras = {"done": torch.zeros(1, dtype=torch.bool)}
            return out.update(extras)

        def _set_seed(self, seed):
            return seed + 1

    def test_exclude(self):
        """Excluding "a" must drop it while "b" and "c" stay available."""
        env = TransformedEnv(
            TestExcludeSelect.EnvWithManyKeys(),
            ExcludeTransform("a"),
        )
        check_env_specs(env)
        assert "a" not in env.reset().keys()
        for kept in ("b", "c"):
            assert kept in env.reset().keys()

    def test_select(self):
        """Selecting "b" and "c" must keep them and drop "a"."""
        env = TransformedEnv(
            TestExcludeSelect.EnvWithManyKeys(),
            SelectTransform("b", "c"),
        )
        check_env_specs(env)
        assert "a" not in env.reset().keys()
        for kept in ("b", "c"):
            assert kept in env.reset().keys()
2319+
2320+
22712321
transforms = [
22722322
ToTensorImage,
22732323
pytest.param(

torchrl/envs/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
CatTensors,
1414
CenterCrop,
1515
Compose,
16+
DiscreteActionProjection,
1617
DoubleToFloat,
18+
ExcludeTransform,
1719
FiniteTensorDictCheck,
1820
FlattenObservation,
21+
FrameSkipTransform,
1922
GrayScale,
2023
gSDENoise,
2124
NoopResetEnv,
@@ -27,13 +30,16 @@
2730
RewardClipping,
2831
RewardScaling,
2932
RewardSum,
33+
SelectTransform,
34+
SqueezeTransform,
3035
StepCounter,
3136
TensorDictPrimer,
3237
ToTensorImage,
3338
Transform,
3439
TransformedEnv,
3540
UnsqueezeTransform,
3641
VecNorm,
42+
VIPRewardTransform,
3743
VIPTransform,
3844
)
3945
from .vec_env import ParallelEnv, SerialEnv

torchrl/envs/transforms/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
CatTensors,
1111
CenterCrop,
1212
Compose,
13+
DiscreteActionProjection,
1314
DoubleToFloat,
15+
ExcludeTransform,
1416
FiniteTensorDictCheck,
1517
FlattenObservation,
1618
FrameSkipTransform,
@@ -24,6 +26,7 @@
2426
RewardClipping,
2527
RewardScaling,
2628
RewardSum,
29+
SelectTransform,
2730
SqueezeTransform,
2831
StepCounter,
2932
TensorDictPrimer,

torchrl/envs/transforms/transforms.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,13 +448,17 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
448448
tensordict = tensordict.clone(False)
449449
tensordict_in = self.transform.inv(tensordict)
450450
tensordict_out = self.base_env._step(tensordict_in)
451-
tensordict_out = tensordict_out.update(
452-
tensordict.exclude(*tensordict_out.keys())
451+
tensordict_out = (
452+
tensordict_out.update( # update the output with the original tensordict
453+
tensordict.exclude(
454+
*tensordict_out.keys()
455+
) # exclude the newly written keys
456+
)
453457
)
454458
next_tensordict = self.transform._step(tensordict_out)
455-
tensordict_out.update(next_tensordict, inplace=False)
459+
# tensordict_out.update(next_tensordict, inplace=False)
456460

457-
return tensordict_out
461+
return next_tensordict
458462

459463
def set_seed(
460464
self, seed: Optional[int] = None, static_seed: bool = False
@@ -2671,3 +2675,89 @@ def transform_observation_spec(
26712675
)
26722676
observation_spec["step_count"].space.minimum = 0
26732677
return observation_spec
2678+
2679+
2680+
class ExcludeTransform(Transform):
    """Drops a fixed set of keys from the tensordicts it processes.

    The same keys are also removed from the observation spec, so that the
    spec matches the filtered data.

    Args:
        *excluded_keys (iterable of str): The name of the keys to exclude. If the key is
            not present, it is simply ignored.

    Raises:
        ValueError: if one of the keys is not a string.
        RuntimeError: if ``"reward"`` is among the keys to exclude.

    """

    inplace = False

    def __init__(self, *excluded_keys):
        super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
        if any(not isinstance(key, str) for key in excluded_keys):
            raise ValueError("excluded_keys must be a list or tuple of strings.")
        self.excluded_keys = excluded_keys
        if "reward" in excluded_keys:
            raise RuntimeError("'reward' cannot be excluded from the keys.")

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        to_drop = self.excluded_keys
        return tensordict.exclude(*to_drop)

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        to_drop = self.excluded_keys
        return tensordict.exclude(*to_drop)

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # When no excluded key appears in the spec, hand it back untouched
        # instead of rebuilding it.
        spec_keys = observation_spec.keys()
        if not any(key in spec_keys for key in self.excluded_keys):
            return observation_spec
        kept = {
            name: spec
            for name, spec in observation_spec.items()
            if name not in self.excluded_keys
        }
        return CompositeSpec(**kept)
2715+
2716+
2717+
class SelectTransform(Transform):
    """Selects keys from the input tensordict.

    In general, the :obj:`ExcludeTransform` should be preferred: this transform
    also selects the "reward" and "done" keys and, when the transform has a
    parent, the keys of the parent's ``input_spec`` (e.g. "action"), but other
    keys may be necessary and would be discarded.

    Args:
        *selected_keys (iterable of str): The name of the keys to select. If the key is
            not present, it is simply ignored.

    Raises:
        ValueError: if one of the keys is not a string.

    """

    inplace = False

    def __init__(self, *selected_keys):
        super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
        if not all(isinstance(item, str) for item in selected_keys):
            # The message names the actual argument (it previously said
            # "excluded_keys", copied over from ExcludeTransform).
            raise ValueError("selected_keys must be a list or tuple of strings.")
        self.selected_keys = selected_keys

    def _select(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Keeps the selected keys plus "reward", "done" and the parent's input keys."""
        input_keys = self.parent.input_spec.keys() if self.parent else []
        # strict=False: keys absent from the tensordict are skipped rather
        # than raising.
        return tensordict.select(
            *self.selected_keys, "reward", "done", *input_keys, strict=False
        )

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._select(tensordict)

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._select(tensordict)

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # Only the explicitly selected entries are kept in the spec.
        return CompositeSpec(
            **{
                key: value
                for key, value in observation_spec.items()
                if key in self.selected_keys
            }
        )

0 commit comments

Comments
 (0)