@@ -662,7 +662,7 @@ def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
         super().__init__(**kw)
         self.priority_key = priority_key

-    def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
+    def _get_priority_item(self, tensordict: TensorDictBase) -> float:
         if "_data" in tensordict.keys():
             tensordict = tensordict.get("_data")

@@ -682,6 +682,23 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
             )
         return priority

+    def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
+        if "_data" in tensordict.keys():
+            tensordict = tensordict.get("_data")
+
+        priority = tensordict.get(self.priority_key, None)
+        if priority is None:
+            return torch.tensor(
+                self._sampler.default_priority,
+                dtype=torch.float,
+                device=tensordict.device,
+            ).expand(tensordict.shape[0])
+
+        priority = priority.reshape(priority.shape[0], -1)
+        priority = _reduce(priority, self._sampler.reduction, dim=1)
+
+        return priority
+
     def add(self, data: TensorDictBase) -> int:
         if self._transform is not None:
             data = self._transform.inv(data)
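
The new `_get_priority_vector` reduces a whole batch of priorities in one tensor call instead of looping per element. Below is a minimal standalone sketch of that reduction using plain torch; the `reduce_per_item` name is made up for illustration, while the real method delegates to the module-level `_reduce` shown further down.

```python
import torch

def reduce_per_item(priority: torch.Tensor, reduction: str = "max") -> torch.Tensor:
    # Flatten every trailing dim: one row of candidate priorities per element.
    priority = priority.reshape(priority.shape[0], -1)
    if reduction == "max":
        return priority.max(dim=1).values
    if reduction == "mean":
        return priority.mean(dim=1)
    raise NotImplementedError(f"Unknown reduction method {reduction}")

td_error = torch.rand(8, 5, 3)           # e.g. TD errors per (item, time, agent)
print(reduce_per_item(td_error).shape)   # torch.Size([8]): one priority per item
```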
@@ -709,61 +726,50 @@ def add(self, data: TensorDictBase) -> int:
         self.update_tensordict_priority(data_add)
         return index

-    def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
-        if is_tensor_collection(tensordicts):
-            tensordicts = TensorDict(
-                {"_data": tensordicts},
-                batch_size=tensordicts.batch_size[:1],
-            )
-            if tensordicts.batch_dims > 1:
-                # we want the tensordict to have one dimension only. The batch size
-                # of the sampled tensordicts can be changed thereafter
-                if not isinstance(tensordicts, LazyStackedTensorDict):
-                    tensordicts = tensordicts.clone(recurse=False)
-                else:
-                    tensordicts = tensordicts.contiguous()
-                # we keep track of the batch size to reinstantiate it when sampling
-                if "_rb_batch_size" in tensordicts.keys():
-                    raise KeyError(
-                        "conflicting key '_rb_batch_size'. Consider removing from data."
-                    )
-                shape = torch.tensor(tensordicts.batch_size[1:]).expand(
-                    tensordicts.batch_size[0], tensordicts.batch_dims - 1
+    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+
+        tensordicts = TensorDict(
+            {"_data": tensordicts},
+            batch_size=tensordicts.batch_size[:1],
+        )
+        if tensordicts.batch_dims > 1:
+            # we want the tensordict to have one dimension only. The batch size
+            # of the sampled tensordicts can be changed thereafter
+            if not isinstance(tensordicts, LazyStackedTensorDict):
+                tensordicts = tensordicts.clone(recurse=False)
+            else:
+                tensordicts = tensordicts.contiguous()
+            # we keep track of the batch size to reinstantiate it when sampling
+            if "_rb_batch_size" in tensordicts.keys():
+                raise KeyError(
+                    "conflicting key '_rb_batch_size'. Consider removing from data."
                 )
-                tensordicts.set("_rb_batch_size", shape)
-            tensordicts.set(
-                "index",
-                torch.zeros(
-                    tensordicts.shape, device=tensordicts.device, dtype=torch.int
-                ),
+            shape = torch.tensor(tensordicts.batch_size[1:]).expand(
+                tensordicts.batch_size[0], tensordicts.batch_dims - 1
             )
-
-        if not is_tensor_collection(tensordicts):
-            stacked_td = torch.stack(tensordicts, 0)
-        else:
-            stacked_td = tensordicts
+            tensordicts.set("_rb_batch_size", shape)
+        tensordicts.set(
+            "index",
+            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
+        )

         if self._transform is not None:
-            tensordicts = self._transform.inv(stacked_td.get("_data"))
-            stacked_td.set("_data", tensordicts)
-            if tensordicts.device is not None:
-                stacked_td = stacked_td.to(tensordicts.device)
+            data = self._transform.inv(tensordicts.get("_data"))
+            tensordicts.set("_data", data)
+            if data.device is not None:
+                tensordicts = tensordicts.to(data.device)

-        index = super()._extend(stacked_td)
-        self.update_tensordict_priority(stacked_td)
+        index = super()._extend(tensordicts)
+        self.update_tensordict_priority(tensordicts)
         return index

     def update_tensordict_priority(self, data: TensorDictBase) -> None:
         if not isinstance(self._sampler, PrioritizedSampler):
             return
         if data.ndim:
-            priority = torch.tensor(
-                [self._get_priority(td) for td in data],
-                dtype=torch.float,
-                device=data.device,
-            )
+            priority = self._get_priority_vector(data)
         else:
-            priority = self._get_priority(data)
+            priority = self._get_priority_item(data)
         index = data.get("index")
         while index.shape != priority.shape:
             # reduce index
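
With `extend` now always wrapping its input in a `"_data"` TensorDict, `update_tensordict_priority` can compute every priority in one batched call rather than one Python `.item()` round trip per element. A hypothetical micro-check that the vectorized reduction matches the old per-row loop; the names and shapes here are illustrative, not from the library.

```python
import torch

td_error = torch.rand(1024, 10)

# Old path: one reduction and one .item() round trip per element.
looped = torch.tensor([row.max().item() for row in td_error])

# New path: a single batched reduction over the flattened trailing dims.
vectorized = td_error.reshape(td_error.shape[0], -1).max(dim=1).values

assert torch.equal(looped, vectorized)
```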
@@ -1010,17 +1016,23 @@ def __call__(self, list_of_tds):
         return self.out


-def _reduce(tensor: torch.Tensor, reduction: str):
+def _reduce(
+    tensor: torch.Tensor, reduction: str, dim: Optional[int] = None
+) -> Union[float, torch.Tensor]:
     """Reduces a tensor given the reduction method."""
     if reduction == "max":
-        return tensor.max().item()
+        result = tensor.max(dim=dim)
     elif reduction == "min":
-        return tensor.min().item()
+        result = tensor.min(dim=dim)
     elif reduction == "mean":
-        return tensor.mean().item()
+        result = tensor.mean(dim=dim)
     elif reduction == "median":
-        return tensor.median().item()
-    raise NotImplementedError(f"Unknown reduction method {reduction}")
+        result = tensor.median(dim=dim)
+    else:
+        raise NotImplementedError(f"Unknown reduction method {reduction}")
+    if isinstance(result, tuple):
+        result = result[0]
+    return result.item() if dim is None else result


 def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
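
The `dim` argument lets `_reduce` serve both call sites: `dim=None` keeps the old scalar `.item()` behaviour, while `dim=1` returns a per-row tensor. A quick sketch of the two modes using plain tensor methods, assuming a recent PyTorch where `max`/`median` return `(values, indices)` named tuples when `dim` is given, which is why the function unpacks tuples.

```python
import torch

t = torch.arange(6, dtype=torch.float).reshape(2, 3)

# dim=None, the pre-existing behaviour: collapse everything to a Python float.
print(t.max().item())          # 5.0

# dim=1, the new vectorized path: one value per row, still a tensor.
print(t.max(dim=1).values)     # tensor([2., 5.])
print(t.median(dim=1).values)  # tensor([1., 4.])
```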