@@ -8583,6 +8583,26 @@ def _indent(s):
8583
8583
f"\n { _indent (out_action_key )} ,\n { _indent (sampling )} ,\n { _indent (categorical )} )"
8584
8584
)
8585
8585
8586
def _custom_arange(self, nint, device):
    """Build the normalized grid of ``nint`` bin positions on the unit interval.

    The grid is created with ``self.dtype`` on ``device`` and depends on
    ``self.sampling``:

    * ``SamplingStrategy.HIGH``: upper edge of each bin, ``(i + 1) / nint``.
    * ``SamplingStrategy.MEDIAN``: midpoint of each bin.
    * any other strategy: lower edge of each bin, ``i / nint``.

    Args:
        nint: number of intervals (bins) to split ``[0, 1]`` into.
        device: device on which the resulting tensor is allocated.

    Returns:
        A 1-D tensor holding exactly ``nint`` values.
    """
    # Count the bins with an integer-valued arange and scale afterwards.
    # ``torch.arange(0.0, 1.0, 1 / nint)`` sizes its output from
    # ``ceil(1.0 / (1 / nint))``, which floating-point rounding can push to
    # ``nint + 1`` (e.g. ``1 / (1 / 49) == 49.00000000000001``); the extra
    # element would break callers that expand the result to ``nint``.
    lower = torch.arange(nint, dtype=self.dtype, device=device) / nint

    strategy = self.SamplingStrategy
    if self.sampling not in (strategy.HIGH, strategy.MEDIAN):
        return lower

    # Mirroring the lower edges around 1 and reversing yields the upper edges.
    upper = (1 - lower).flip(0)
    if self.sampling == strategy.MEDIAN:
        return (lower + upper) / 2
    return upper
8605
+
8586
8606
def transform_input_spec (self , input_spec ):
8587
8607
try :
8588
8608
action_spec = self .parent .full_action_spec_unbatched [self .in_keys_inv [0 ]]
@@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec):
8611
8631
num_intervals = int (num_intervals .squeeze ())
8612
8632
self .num_intervals = torch .as_tensor (num_intervals )
8613
8633
8614
- def custom_arange (nint ):
8615
- result = torch .arange (
8616
- start = 0.0 ,
8617
- end = 1.0 ,
8618
- step = 1 / nint ,
8619
- dtype = self .dtype ,
8620
- device = action_spec .device ,
8621
- )
8622
- result_ = result
8623
- if self .sampling in (
8624
- self .SamplingStrategy .HIGH ,
8625
- self .SamplingStrategy .MEDIAN ,
8626
- ):
8627
- result_ = (1 - result ).flip (0 )
8628
- if self .sampling == self .SamplingStrategy .MEDIAN :
8629
- result = (result + result_ ) / 2
8630
- else :
8631
- result = result_
8632
- return result
8633
-
8634
8634
if isinstance (num_intervals , int ):
8635
8635
arange = (
8636
- custom_arange (num_intervals ).expand ((* n_act , num_intervals ))
8636
+ self ._custom_arange (num_intervals , action_spec .device ).expand (
8637
+ (* n_act , num_intervals )
8638
+ )
8637
8639
* interval
8638
8640
)
8639
8641
low = action_spec .low
@@ -8642,7 +8644,7 @@ def custom_arange(nint):
8642
8644
self .register_buffer ("intervals" , low + arange )
8643
8645
else :
8644
8646
arange = [
8645
- custom_arange (_num_intervals ) * interval
8647
+ self . _custom_arange (_num_intervals , action_spec . device ) * interval
8646
8648
for _num_intervals , interval in zip (
8647
8649
num_intervals .tolist (), interval .unbind (- 2 )
8648
8650
)
0 commit comments