Skip to content

Commit 985f5d1

Browse files
authored
[Feature] Max pool Transform (#841)
1 parent bdf2bcd commit 985f5d1

File tree

5 files changed

+135
-4
lines changed

5 files changed

+135
-4
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ in the environment. The keys to be included in this inverse transform are passed
229229
SqueezeTransform
230230
StepCounter
231231
TensorDictPrimer
232+
TimeMaxPool
232233
ToTensorImage
233234
UnsqueezeTransform
234235
VecNorm

test/test_transforms.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
SqueezeTransform,
5757
StepCounter,
5858
TensorDictPrimer,
59+
TimeMaxPool,
5960
ToTensorImage,
6061
TransformedEnv,
6162
UnsqueezeTransform,
@@ -106,7 +107,6 @@ def _test_vecnorm_subproc_auto(
106107

107108
@pytest.mark.parametrize("nprc", [2, 5])
108109
def test_vecnorm_parallel_auto(self, nprc):
109-
110110
queues = []
111111
prcs = []
112112
if _has_gym:
@@ -864,6 +864,34 @@ def test_sum_reward(self, keys, device):
864864
assert "some_extra_observation" in transformed_observation_spec2.keys()
865865
assert "episode_reward" in transformed_observation_spec2.keys()
866866

867+
@pytest.mark.parametrize("T", [2, 4])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("device", get_available_devices())
def test_time_max_pool(self, T, seq_len, device):
    """TimeMaxPool must return the elementwise max over the last T observations."""
    batch = 1
    nodes = 4
    keys = ["observation"]
    time_max_pool = TimeMaxPool(keys, T=T)

    # Build a random observation sequence and compute the expected max
    # over the trailing window of length T.
    tensor_list = []
    for _ in range(seq_len):
        # allocate directly on the target device instead of rand().to(device)
        tensor_list.append(torch.rand(batch, nodes, device=device))
    max_vals, _ = torch.max(torch.stack(tensor_list[-T:]), dim=0)

    # Feed the whole sequence through the transform step by step; after the
    # final step the pooled output must equal the expected window max.
    for i in range(seq_len):
        env_td = TensorDict(
            {
                "observation": tensor_list[i],
            },
            device=device,
            batch_size=[batch],
        )
        transformed_td = time_max_pool(env_td)

    assert (max_vals == transformed_td["observation"]).all()
867895
@pytest.mark.parametrize("batch", [[], [1], [3, 2]])
868896
@pytest.mark.parametrize(
869897
"keys",
@@ -1667,7 +1695,6 @@ def test_append(self):
16671695
assert obs_spec.shape[-1] == 4 * env.base_env.observation_spec[key].shape[-1]
16681696

16691697
def test_insert(self):
1670-
16711698
env = ContinuousActionVecMockEnv()
16721699
obs_spec = env.observation_spec
16731700
(key,) = itertools.islice(obs_spec.keys(), 1)

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
SqueezeTransform,
3535
StepCounter,
3636
TensorDictPrimer,
37+
TimeMaxPool,
3738
ToTensorImage,
3839
Transform,
3940
TransformedEnv,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SqueezeTransform,
3131
StepCounter,
3232
TensorDictPrimer,
33+
TimeMaxPool,
3334
ToTensorImage,
3435
Transform,
3536
TransformedEnv,

torchrl/envs/transforms/transforms.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,7 +2682,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
26822682

26832683
episode_specs = {}
26842684
if isinstance(reward_spec, CompositeSpec):
2685-
26862685
# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
26872686
if not all(k in reward_spec.keys() for k in self.in_keys):
26882687
raise KeyError("Not all in_keys are present in ´reward_spec´")
@@ -2697,7 +2696,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
26972696
episode_specs.update({out_key: episode_spec})
26982697

26992698
else:
2700-
27012699
# If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
27022700
if not set(self.in_keys) == {"reward"}:
27032701
raise KeyError(
@@ -2882,3 +2880,106 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
28822880
if key in self.selected_keys
28832881
}
28842882
)
2883+
2884+
2885+
class TimeMaxPool(Transform):
    """Take the maximum value in each position over the last T observations.

    This transform takes, for every tensor listed in ``in_keys``, the
    elementwise maximum over the last ``T`` time steps.

    Args:
        in_keys (sequence of str, optional): input keys on which the max pool
            will be applied. Defaults to ``["observation"]`` if left empty.
        out_keys (sequence of str, optional): output keys where the output will
            be written. Defaults to ``in_keys`` if left empty.
        T (int, optional): number of time steps over which to apply max pooling.
            Defaults to 1.
    """

    inplace = False
    invertible = False

    def __init__(
        self,
        in_keys: Optional[Sequence[str]] = None,
        out_keys: Optional[Sequence[str]] = None,
        T: int = 1,
    ):
        if in_keys is None:
            in_keys = ["observation"]
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        if T < 1:
            raise ValueError(
                "TimeMaxPool T parameter should have a value greater or equal to one."
            )
        if len(self.in_keys) != len(self.out_keys):
            raise ValueError(
                "TimeMaxPool in_keys and out_keys don't have the same number of elements"
            )
        self.buffer_size = T
        # One lazily-materialized buffer per input key: shape/dtype/device are
        # unknown until the first tensordict is seen in _call, so start with
        # UninitializedBuffer placeholders.
        for in_key in self.in_keys:
            buffer_name = f"_maxpool_buffer_{in_key}"
            setattr(
                self,
                buffer_name,
                torch.nn.parameter.UninitializedBuffer(
                    device=torch.device("cpu"), dtype=torch.get_default_dtype()
                ),
            )

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Zero the pooling buffers for the environments being reset."""
        # Non-batched environments: clear the whole buffer.
        # NOTE(review): batch_size[0] == 1 is treated as non-batched here, so a
        # partial "_reset" mask is ignored for single-env batches — confirm intended.
        if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
            for in_key in self.in_keys:
                buffer_name = f"_maxpool_buffer_{in_key}"
                buffer = getattr(self, buffer_name)
                # Buffers not materialized yet have nothing to clear.
                if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
                    continue
                buffer.fill_(0.0)

        # Batched environments: clear only the slots flagged by "_reset"
        # (default: reset everything).
        else:
            _reset = tensordict.get(
                "_reset",
                torch.ones(
                    tensordict.batch_size,
                    dtype=torch.bool,
                    device=tensordict.device,
                ),
            )
            for in_key in self.in_keys:
                buffer_name = f"_maxpool_buffer_{in_key}"
                buffer = getattr(self, buffer_name)
                if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
                    continue
                # Buffer layout is (T, batch, *obs_shape); dim 1 is the batch dim.
                buffer[:, _reset] = 0.0

        return tensordict

    def _make_missing_buffer(self, data, buffer_name):
        """Materialize an uninitialized buffer to (T, *data.shape), zeroed, matching data's dtype/device."""
        buffer = getattr(self, buffer_name)
        buffer.materialize((self.buffer_size,) + data.shape)
        buffer = buffer.to(data.dtype).to(data.device).zero_()
        setattr(self, buffer_name, buffer)

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Update the episode tensordict with max pooled keys."""
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            # Lazy init of buffers on the first observation seen.
            buffer_name = f"_maxpool_buffer_{in_key}"
            buffer = getattr(self, buffer_name)
            if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
                data = tensordict[in_key]
                self._make_missing_buffer(data, buffer_name)
            # shift observations 1 position to the right along the time dim
            buffer.copy_(torch.roll(buffer, shifts=1, dims=0))
            # overwrite the (rolled-in) oldest slot with the new observation
            buffer[0].copy_(tensordict[in_key])
            # apply max pooling over the time dimension
            pooled_tensor, _ = buffer.max(dim=0)
            # write the pooled result to the output key
            tensordict.set(out_key, pooled_tensor)

        return tensordict

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # Pooling does not change shape or dtype, so the spec passes through.
        return observation_spec

0 commit comments

Comments
 (0)