@@ -112,10 +112,10 @@ def __init__(
112
112
113
113
super ().__init__ ()
114
114
115
- self .register_buffer ("eps_init" , torch .as_tensor ([ eps_init ] ))
116
- self .register_buffer ("eps_end" , torch .as_tensor ([ eps_end ] ))
115
+ self .register_buffer ("eps_init" , torch .as_tensor (eps_init ))
116
+ self .register_buffer ("eps_end" , torch .as_tensor (eps_end ))
117
117
self .annealing_num_steps = annealing_num_steps
118
- self .register_buffer ("eps" , torch .as_tensor ([ eps_init ] , dtype = torch .float32 ))
118
+ self .register_buffer ("eps" , torch .as_tensor (eps_init , dtype = torch .float32 ))
119
119
120
120
if spec is not None :
121
121
if not isinstance (spec , Composite ) and len (self .out_keys ) >= 1 :
@@ -275,13 +275,13 @@ def __init__(
275
275
super ().__init__ (policy )
276
276
if sigma_end > sigma_init :
277
277
raise RuntimeError ("sigma should decrease over time or be constant" )
278
- self .register_buffer ("sigma_init" , torch .tensor ([ sigma_init ] , device = device ))
279
- self .register_buffer ("sigma_end" , torch .tensor ([ sigma_end ] , device = device ))
278
+ self .register_buffer ("sigma_init" , torch .tensor (sigma_init , device = device ))
279
+ self .register_buffer ("sigma_end" , torch .tensor (sigma_end , device = device ))
280
280
self .annealing_num_steps = annealing_num_steps
281
- self .register_buffer ("mean" , torch .tensor ([ mean ] , device = device ))
282
- self .register_buffer ("std" , torch .tensor ([ std ] , device = device ))
281
+ self .register_buffer ("mean" , torch .tensor (mean , device = device ))
282
+ self .register_buffer ("std" , torch .tensor (std , device = device ))
283
283
self .register_buffer (
284
- "sigma" , torch .tensor ([ sigma_init ] , dtype = torch .float32 , device = device )
284
+ "sigma" , torch .tensor (sigma_init , dtype = torch .float32 , device = device )
285
285
)
286
286
self .action_key = action_key
287
287
self .out_keys = list (self .td_module .out_keys )
@@ -423,13 +423,13 @@ def __init__(
423
423
424
424
super ().__init__ ()
425
425
426
- self .register_buffer ("sigma_init" , torch .tensor ([ sigma_init ] , device = device ))
427
- self .register_buffer ("sigma_end" , torch .tensor ([ sigma_end ] , device = device ))
426
+ self .register_buffer ("sigma_init" , torch .tensor (sigma_init , device = device ))
427
+ self .register_buffer ("sigma_end" , torch .tensor (sigma_end , device = device ))
428
428
self .annealing_num_steps = annealing_num_steps
429
- self .register_buffer ("mean" , torch .tensor ([ mean ] , device = device ))
430
- self .register_buffer ("std" , torch .tensor ([ std ] , device = device ))
429
+ self .register_buffer ("mean" , torch .tensor (mean , device = device ))
430
+ self .register_buffer ("std" , torch .tensor (std , device = device ))
431
431
self .register_buffer (
432
- "sigma" , torch .tensor ([ sigma_init ] , dtype = torch .float32 , device = device )
432
+ "sigma" , torch .tensor (sigma_init , dtype = torch .float32 , device = device )
433
433
)
434
434
435
435
if spec is not None :
@@ -628,16 +628,16 @@ def __init__(
628
628
key = action_key ,
629
629
device = device ,
630
630
)
631
- self .register_buffer ("eps_init" , torch .tensor ([ eps_init ] , device = device ))
632
- self .register_buffer ("eps_end" , torch .tensor ([ eps_end ] , device = device ))
631
+ self .register_buffer ("eps_init" , torch .tensor (eps_init , device = device ))
632
+ self .register_buffer ("eps_end" , torch .tensor (eps_end , device = device ))
633
633
if self .eps_end > self .eps_init :
634
634
raise ValueError (
635
635
"eps should decrease over time or be constant, "
636
636
f"got eps_init={ eps_init } and eps_end={ eps_end } "
637
637
)
638
638
self .annealing_num_steps = annealing_num_steps
639
639
self .register_buffer (
640
- "eps" , torch .tensor ([ eps_init ] , dtype = torch .float32 , device = device )
640
+ "eps" , torch .tensor (eps_init , dtype = torch .float32 , device = device )
641
641
)
642
642
self .out_keys = list (self .td_module .out_keys ) + self .ou .out_keys
643
643
self .is_init_key = is_init_key
@@ -840,16 +840,16 @@ def __init__(
840
840
device = device ,
841
841
)
842
842
843
- self .register_buffer ("eps_init" , torch .tensor ([ eps_init ] , device = device ))
844
- self .register_buffer ("eps_end" , torch .tensor ([ eps_end ] , device = device ))
843
+ self .register_buffer ("eps_init" , torch .tensor (eps_init , device = device ))
844
+ self .register_buffer ("eps_end" , torch .tensor (eps_end , device = device ))
845
845
if self .eps_end > self .eps_init :
846
846
raise ValueError (
847
847
"eps should decrease over time or be constant, "
848
848
f"got eps_init={ eps_init } and eps_end={ eps_end } "
849
849
)
850
850
self .annealing_num_steps = annealing_num_steps
851
851
self .register_buffer (
852
- "eps" , torch .tensor ([ eps_init ] , dtype = torch .float32 , device = device )
852
+ "eps" , torch .tensor (eps_init , dtype = torch .float32 , device = device )
853
853
)
854
854
855
855
self .in_keys = [self .ou .key ]
0 commit comments