@@ -2682,7 +2682,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec

         episode_specs = {}
         if isinstance(reward_spec, CompositeSpec):
-
             # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
             if not all(k in reward_spec.keys() for k in self.in_keys):
                 raise KeyError("Not all in_keys are present in `reward_spec`")
@@ -2697,7 +2696,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
             episode_specs.update({out_key: episode_spec})

         else:
-
             # If reward_spec is not a CompositeSpec, the only in_key should be `reward`
             if not set(self.in_keys) == {"reward"}:
                 raise KeyError(
@@ -2882,3 +2880,106 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
                 if key in self.selected_keys
             }
         )
+
+
+class TimeMaxPool(Transform):
+    """Take the maximum value in each position over the last T observations.
+
+    This transform takes the maximum value in each position for all in_keys tensors over the last T time steps.
+
+    Args:
+        in_keys (sequence of str, optional): input keys on which the max pool will be applied. Defaults to ["observation"] if left empty.
+        out_keys (sequence of str, optional): output keys where the output will be written. Defaults to `in_keys` if left empty.
+        T (int, optional): Number of time steps over which to apply max pooling.
+    """
+
+    inplace = False
+    invertible = False
+
+    def __init__(
+        self,
+        in_keys: Optional[Sequence[str]] = None,
+        out_keys: Optional[Sequence[str]] = None,
+        T: int = 1,
+    ):
+        if in_keys is None:
+            in_keys = ["observation"]
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        if T < 1:
+            raise ValueError(
+                "TimeMaxPool T parameter should be greater than or equal to one."
+            )
+        if len(self.in_keys) != len(self.out_keys):
+            raise ValueError(
+                "TimeMaxPool in_keys and out_keys must have the same number of elements."
+            )
+        self.buffer_size = T
+        for in_key in self.in_keys:
+            buffer_name = f"_maxpool_buffer_{in_key}"
+            setattr(
+                self,
+                buffer_name,
+                torch.nn.parameter.UninitializedBuffer(
+                    device=torch.device("cpu"), dtype=torch.get_default_dtype()
+                ),
+            )
+
+    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """Resets the pooling buffers."""
+        # Non-batched environments
+        if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
+            for in_key in self.in_keys:
+                buffer_name = f"_maxpool_buffer_{in_key}"
+                buffer = getattr(self, buffer_name)
+                if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
+                    continue
+                buffer.fill_(0.0)
+
+        # Batched environments
+        else:
+            _reset = tensordict.get(
+                "_reset",
+                torch.ones(
+                    tensordict.batch_size,
+                    dtype=torch.bool,
+                    device=tensordict.device,
+                ),
+            )
+            for in_key in self.in_keys:
+                buffer_name = f"_maxpool_buffer_{in_key}"
+                buffer = getattr(self, buffer_name)
+                if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
+                    continue
+                buffer[:, _reset] = 0.0
+
+        return tensordict
+
+    def _make_missing_buffer(self, data, buffer_name):
+        buffer = getattr(self, buffer_name)
+        buffer.materialize((self.buffer_size,) + data.shape)
+        buffer = buffer.to(data.dtype).to(data.device).zero_()
+        setattr(self, buffer_name, buffer)
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """Update the episode tensordict with max pooled keys."""
+        for in_key, out_key in zip(self.in_keys, self.out_keys):
+            # Lazy init of buffers
+            buffer_name = f"_maxpool_buffer_{in_key}"
+            buffer = getattr(self, buffer_name)
+            if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
+                data = tensordict[in_key]
+                self._make_missing_buffer(data, buffer_name)
+            # shift obs 1 position to the right
+            buffer.copy_(torch.roll(buffer, shifts=1, dims=0))
+            # add new obs
+            buffer[0].copy_(tensordict[in_key])
+            # apply max pooling
+            pooled_tensor, _ = buffer.max(dim=0)
+            # add to tensordict
+            tensordict.set(out_key, pooled_tensor)
+
+        return tensordict
+
+    @_apply_to_composite
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        return observation_spec
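
A quick, self-contained check of the pooling mechanics used in _call above (plain torch, runnable as-is; the tensor values are illustrative only):

    import torch

    # A (T, *shape) buffer is rolled one step, the newest frame is written at
    # index 0, and the elementwise max over the time dimension is emitted:
    # the same roll/copy/max sequence as in TimeMaxPool._call.
    T = 3
    buffer = torch.zeros(T, 2)
    for frame in torch.arange(6.0).reshape(3, 2):  # frames [0,1], [2,3], [4,5]
        buffer = torch.roll(buffer, shifts=1, dims=0)
        buffer[0] = frame
    pooled, _ = buffer.max(dim=0)
    assert torch.equal(pooled, torch.tensor([4.0, 5.0]))  # max of the last T frames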
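
A minimal usage sketch, assuming the transform is exported from torchrl.envs.transforms and a gym backend is installed; the environment name is an arbitrary example, not part of this diff:

    from torchrl.envs import TransformedEnv
    from torchrl.envs.libs.gym import GymEnv  # assumes gym/gymnasium is installed
    from torchrl.envs.transforms import TimeMaxPool  # assumed export path

    # Each step replaces "observation" with the elementwise max of the last 4 observations.
    env = TransformedEnv(GymEnv("Pendulum-v1"), TimeMaxPool(in_keys=["observation"], T=4))
    td = env.reset()
    for _ in range(8):
        td = env.rand_step(td)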