@@ -2620,6 +2620,8 @@ class CatFrames(ObservationTransform):
2620
2620
reset indicator. Must be unique. If not provided, defaults to the
2621
2621
only reset key of the parent environment (if it has only one)
2622
2622
and raises an exception otherwise.
2623
+ done_key (NestedKey, optional): the done key to be used as partial
2624
+ done indicator. Must be unique. If not provided, defaults to ``"done"``.
2623
2625
2624
2626
Examples:
2625
2627
>>> from torchrl.envs.libs.gym import GymEnv
@@ -2700,6 +2702,7 @@ def __init__(
2700
2702
padding_value = 0 ,
2701
2703
as_inverse = False ,
2702
2704
reset_key : NestedKey | None = None ,
2705
+ done_key : NestedKey | None = None ,
2703
2706
):
2704
2707
if in_keys is None :
2705
2708
in_keys = IMAGE_KEYS
@@ -2733,6 +2736,19 @@ def __init__(
2733
2736
# keeps track of calls to _reset since it's only _call that will populate the buffer
2734
2737
self .as_inverse = as_inverse
2735
2738
self .reset_key = reset_key
2739
+ self .done_key = done_key
2740
+
2741
@property
def done_key(self):
    """Key of the done entry used by this transform.

    Reads ``_done_key`` from the instance ``__dict__``. When it is missing
    or ``None``, the default ``"done"`` is stored on the instance and
    returned.
    """
    key = self.__dict__.get("_done_key")
    if key is not None:
        return key
    # Nothing was configured: persist the default so later reads are direct.
    self._done_key = "done"
    return "done"


@done_key.setter
def done_key(self, value):
    """Store ``value`` as the done key (``None`` restores the default on read)."""
    self._done_key = value
2736
2752
2737
2753
@property
2738
2754
def reset_key (self ):
@@ -2829,15 +2845,6 @@ def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
2829
2845
# make linter happy. An exception has already been raised
2830
2846
raise NotImplementedError
2831
2847
2832
- # # this duplicates the code below, but only for _reset values
2833
- # if _all:
2834
- # buffer.copy_(torch.roll(buffer_reset, shifts=-d, dims=dim))
2835
- # buffer_reset = buffer
2836
- # else:
2837
- # buffer_reset = buffer[_reset] = torch.roll(
2838
- # buffer_reset, shifts=-d, dims=dim
2839
- # )
2840
- # add new obs
2841
2848
if self .dim < 0 :
2842
2849
n = buffer_reset .ndimension () + self .dim
2843
2850
else :
@@ -2906,69 +2913,136 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
2906
2913
if i != tensordict .ndim - 1 :
2907
2914
tensordict = tensordict .transpose (tensordict .ndim - 1 , i )
2908
2915
# first sort the in_keys with strings and non-strings
2909
- in_keys = list (
2910
- zip (
2911
- (in_key , out_key )
2912
- for in_key , out_key in zip (self .in_keys , self .out_keys )
2913
- if isinstance (in_key , str ) or len (in_key ) == 1
2914
- )
2915
- )
2916
- in_keys += list (
2917
- zip (
2918
- (in_key , out_key )
2919
- for in_key , out_key in zip (self .in_keys , self .out_keys )
2920
- if not isinstance (in_key , str ) and not len (in_key ) == 1
2916
+ keys = [
2917
+ (in_key , out_key )
2918
+ for in_key , out_key in zip (self .in_keys , self .out_keys )
2919
+ if isinstance (in_key , str )
2920
+ ]
2921
+ keys += [
2922
+ (in_key , out_key )
2923
+ for in_key , out_key in zip (self .in_keys , self .out_keys )
2924
+ if not isinstance (in_key , str )
2925
+ ]
2926
+
2927
def unfold_done(done, N):
    """Build the per-window mask and the per-step reset flags from ``done``.

    The unfolded dimension is ``tensordict.ndim - 1`` of the enclosing
    ``tensordict`` (captured from the outer scope).

    Args:
        done: done signal; must support ``|`` (presumably a boolean
            tensor - confirm against the caller).
        N: window length used for the padding and the unfolding.

    Returns:
        A ``(mask, reset)`` tuple. ``mask`` has one extra trailing
        dimension of size ``N``: entry ``i`` of a window is set when a
        reset occurs at a later position (``i + 1 .. N - 1``) of that same
        window. ``reset`` is ``done`` shifted by one step along the
        unfolded dimension, with its first step forced to 1.
    """
    time_dim = tensordict.ndim - 1
    prefix = (slice(None),) * time_dim
    # Left-pad with N - 1 "no reset" steps, mark the first real step as a
    # reset, then shift ``done`` by one: a done at t is a reset at t + 1.
    reset = torch.cat(
        [
            torch.zeros_like(done[prefix + (slice(N - 1),)]),
            torch.ones_like(done[prefix + (slice(1),)]),
            done[prefix + (slice(None, -1),)],
        ],
        time_dim,
    )
    windows = reset.unfold(time_dim, N, 1)
    # Suffix OR over the window dimension, walking from the last position
    # to the first; the seed of zeros becomes the entry of the last slot.
    suffix_or = [torch.zeros_like(windows[..., -1])]
    for window_slot in reversed(windows.unbind(-1)):
        suffix_or.append(window_slot | suffix_or[-1])
    # Drop the full-window OR (first after reversing) so that slot i only
    # accounts for resets strictly after i.
    mask = torch.stack(list(reversed(suffix_or))[1:], -1)
    # Discard the left padding and force a reset on the first step.
    reset = reset[prefix + (slice(N - 1, None),)]
    reset[prefix + (0,)] = 1
    return mask, reset
2947
+
2948
+ done = tensordict .get (("next" , self .done_key ))
2949
+ done_mask , reset = unfold_done (done , self .N )
2950
+
2951
+ for in_key , out_key in keys :
2924
2952
# check if we have an obs in "next" that has already been processed.
2925
2953
# If so, we must add an offset
2926
- data = tensordict .get (in_key )
2954
+ data_orig = data = tensordict .get (in_key )
2955
+ n_feat = data_orig .shape [data .ndim + self .dim ]
2956
+ first_val = None
2927
2957
if isinstance (in_key , tuple ) and in_key [0 ] == "next" :
2928
2958
# let's get the out_key we have already processed
2929
- prev_out_key = dict (zip (self .in_keys , self .out_keys ))[in_key [1 ]]
2930
- prev_val = tensordict .get (prev_out_key )
2931
- # the first item is located along `dim+1` at the last index of the
2932
- # first time index
2933
- idx = (
2934
- [slice (None )] * (tensordict .ndim - 1 )
2935
- + [0 ]
2936
- + [..., - 1 ]
2937
- + [slice (None )] * (abs (self .dim ) - 1 )
2959
+ prev_out_key = dict (zip (self .in_keys , self .out_keys )).get (
2960
+ in_key [1 ], None
2938
2961
)
2939
- first_val = prev_val [tuple (idx )].unsqueeze (tensordict .ndim - 1 )
2940
- data0 = [first_val ] * (self .N - 1 )
2941
- if self .padding == "constant" :
2942
- data0 = [
2943
- torch .full_like (elt , self .padding_value ) for elt in data0 [:- 1 ]
2944
- ] + data0 [- 1 :]
2945
- elif self .padding == "same" :
2946
- pass
2947
- else :
2948
- # make linter happy. An exception has already been raised
2949
- raise NotImplementedError
2950
- elif self .padding == "same" :
2951
- idx = [slice (None )] * (tensordict .ndim - 1 ) + [0 ]
2952
- data0 = [data [tuple (idx )].unsqueeze (tensordict .ndim - 1 )] * (self .N - 1 )
2953
- elif self .padding == "constant" :
2954
- idx = [slice (None )] * (tensordict .ndim - 1 ) + [0 ]
2955
- data0 = [
2956
- torch .full_like (data [tuple (idx )], self .padding_value ).unsqueeze (
2957
- tensordict .ndim - 1
2962
+ if prev_out_key is not None :
2963
+ prev_val = tensordict .get (prev_out_key )
2964
+ # n_feat = prev_val.shape[data.ndim + self.dim] // self.N
2965
+ first_val = prev_val .unflatten (
2966
+ data .ndim + self .dim , (self .N , n_feat )
2958
2967
)
2959
- ] * (self .N - 1 )
2960
- else :
2961
- # make linter happy. An exception has already been raised
2962
- raise NotImplementedError
2968
+
2969
+ idx = [slice (None )] * (tensordict .ndim - 1 ) + [0 ]
2970
+ data0 = [
2971
+ torch .full_like (data [tuple (idx )], self .padding_value ).unsqueeze (
2972
+ tensordict .ndim - 1
2973
+ )
2974
+ ] * (self .N - 1 )
2963
2975
2964
2976
data = torch .cat (data0 + [data ], tensordict .ndim - 1 )
2965
2977
2966
2978
data = data .unfold (tensordict .ndim - 1 , self .N , 1 )
2979
+
2980
+ # Place -1 dim at self.dim place before squashing
2981
+ done_mask_expand = expand_as_right (done_mask , data )
2967
2982
data = data .permute (
2968
- * range (0 , data .ndim + self .dim ),
2983
+ * range (0 , data .ndim + self .dim - 1 ),
2984
+ - 1 ,
2985
+ * range (data .ndim + self .dim - 1 , data .ndim - 1 ),
2986
+ )
2987
+ done_mask_expand = done_mask_expand .permute (
2988
+ * range (0 , done_mask_expand .ndim + self .dim - 1 ),
2969
2989
- 1 ,
2970
- * range (data .ndim + self .dim , data .ndim - 1 ),
2990
+ * range (done_mask_expand .ndim + self .dim - 1 , done_mask_expand .ndim - 1 ),
2971
2991
)
2992
+ if self .padding != "same" :
2993
+ data = torch .where (done_mask_expand , self .padding_value , data )
2994
+ else :
2995
+ # TODO: This is a pretty bad implementation, could be
2996
+ # made more efficient but it works!
2997
+ reset_vals = list (data_orig [reset .squeeze (- 1 )].unbind (0 ))
2998
+ j_ = float ("inf" )
2999
+ reps = []
3000
+ d = data .ndim + self .dim - 1
3001
+ for j in done_mask_expand .sum (d ).sum (d ).view (- 1 ) // n_feat :
3002
+ if j > j_ :
3003
+ reset_vals = reset_vals [1 :]
3004
+ reps .extend ([reset_vals [0 ]] * int (j ))
3005
+ j_ = j
3006
+ reps = torch .stack (reps )
3007
+ data = torch .masked_scatter (data , done_mask_expand , reps .reshape (- 1 ))
3008
+
3009
+ if first_val is not None :
3010
+ # Aggregate reset along last dim
3011
+ reset = reset .any (- 1 , True )
3012
+ rexp = reset .expand (* reset .shape [:- 1 ], n_feat )
3013
+ rexp = torch .cat (
3014
+ [
3015
+ torch .zeros_like (
3016
+ data0 [0 ].repeat_interleave (
3017
+ len (data0 ), dim = tensordict .ndim - 1
3018
+ ),
3019
+ dtype = torch .bool ,
3020
+ ),
3021
+ rexp ,
3022
+ ],
3023
+ tensordict .ndim - 1 ,
3024
+ )
3025
+ rexp = rexp .unfold (tensordict .ndim - 1 , self .N , 1 )
3026
+ rexp_orig = rexp
3027
+ rexp = torch .cat ([rexp [..., 1 :], torch .zeros_like (rexp [..., - 1 :])], - 1 )
3028
+ if self .padding == "same" :
3029
+ rexp_orig = rexp_orig .flip (- 1 ).cumsum (- 1 ).flip (- 1 ).bool ()
3030
+ rexp = rexp .flip (- 1 ).cumsum (- 1 ).flip (- 1 ).bool ()
3031
+ rexp_orig = torch .cat (
3032
+ [torch .zeros_like (rexp_orig [..., - 1 :]), rexp_orig [..., 1 :]], - 1
3033
+ )
3034
+ rexp = rexp .permute (
3035
+ * range (0 , rexp .ndim + self .dim - 1 ),
3036
+ - 1 ,
3037
+ * range (rexp .ndim + self .dim - 1 , rexp .ndim - 1 ),
3038
+ )
3039
+ rexp_orig = rexp_orig .permute (
3040
+ * range (0 , rexp_orig .ndim + self .dim - 1 ),
3041
+ - 1 ,
3042
+ * range (rexp_orig .ndim + self .dim - 1 , rexp_orig .ndim - 1 ),
3043
+ )
3044
+ data [rexp ] = first_val [rexp_orig ]
3045
+ data = data .flatten (data .ndim + self .dim - 1 , data .ndim + self .dim )
2972
3046
tensordict .set (out_key , data )
2973
3047
if tensordict_orig is not tensordict :
2974
3048
tensordict_orig = tensordict .transpose (tensordict .ndim - 1 , i )
0 commit comments