Skip to content

Commit a47b32c

Browse files
author
Vincent Moens
committed
[BugFix] make buffers zero-dim in exploration modules
ghstack-source-id: fd2705e Pull Request resolved: #2591
1 parent 14b2775 commit a47b32c

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

test/_utils_internal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def get_available_devices():
167167
def get_default_devices():
168168
num_cuda = torch.cuda.device_count()
169169
if num_cuda == 0:
170-
if torch.mps.is_available():
171-
return [torch.device("mps:0")]
170+
# if torch.mps.is_available():
171+
# return [torch.device("mps:0")]
172172
return [torch.device("cpu")]
173173
elif num_cuda == 1:
174174
return [torch.device("cuda:0")]

torchrl/modules/tensordict_module/exploration.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def __init__(
112112

113113
super().__init__()
114114

115-
self.register_buffer("eps_init", torch.as_tensor([eps_init]))
116-
self.register_buffer("eps_end", torch.as_tensor([eps_end]))
115+
self.register_buffer("eps_init", torch.as_tensor(eps_init))
116+
self.register_buffer("eps_end", torch.as_tensor(eps_end))
117117
self.annealing_num_steps = annealing_num_steps
118-
self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32))
118+
self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32))
119119

120120
if spec is not None:
121121
if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
@@ -275,13 +275,13 @@ def __init__(
275275
super().__init__(policy)
276276
if sigma_end > sigma_init:
277277
raise RuntimeError("sigma should decrease over time or be constant")
278-
self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device))
279-
self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device))
278+
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
279+
self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device))
280280
self.annealing_num_steps = annealing_num_steps
281-
self.register_buffer("mean", torch.tensor([mean], device=device))
282-
self.register_buffer("std", torch.tensor([std], device=device))
281+
self.register_buffer("mean", torch.tensor(mean, device=device))
282+
self.register_buffer("std", torch.tensor(std, device=device))
283283
self.register_buffer(
284-
"sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device)
284+
"sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device)
285285
)
286286
self.action_key = action_key
287287
self.out_keys = list(self.td_module.out_keys)
@@ -423,13 +423,13 @@ def __init__(
423423

424424
super().__init__()
425425

426-
self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device))
427-
self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device))
426+
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
427+
self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device))
428428
self.annealing_num_steps = annealing_num_steps
429-
self.register_buffer("mean", torch.tensor([mean], device=device))
430-
self.register_buffer("std", torch.tensor([std], device=device))
429+
self.register_buffer("mean", torch.tensor(mean, device=device))
430+
self.register_buffer("std", torch.tensor(std, device=device))
431431
self.register_buffer(
432-
"sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device)
432+
"sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device)
433433
)
434434

435435
if spec is not None:
@@ -628,16 +628,16 @@ def __init__(
628628
key=action_key,
629629
device=device,
630630
)
631-
self.register_buffer("eps_init", torch.tensor([eps_init], device=device))
632-
self.register_buffer("eps_end", torch.tensor([eps_end], device=device))
631+
self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
632+
self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
633633
if self.eps_end > self.eps_init:
634634
raise ValueError(
635635
"eps should decrease over time or be constant, "
636636
f"got eps_init={eps_init} and eps_end={eps_end}"
637637
)
638638
self.annealing_num_steps = annealing_num_steps
639639
self.register_buffer(
640-
"eps", torch.tensor([eps_init], dtype=torch.float32, device=device)
640+
"eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
641641
)
642642
self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
643643
self.is_init_key = is_init_key
@@ -840,16 +840,16 @@ def __init__(
840840
device=device,
841841
)
842842

843-
self.register_buffer("eps_init", torch.tensor([eps_init], device=device))
844-
self.register_buffer("eps_end", torch.tensor([eps_end], device=device))
843+
self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
844+
self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
845845
if self.eps_end > self.eps_init:
846846
raise ValueError(
847847
"eps should decrease over time or be constant, "
848848
f"got eps_init={eps_init} and eps_end={eps_end}"
849849
)
850850
self.annealing_num_steps = annealing_num_steps
851851
self.register_buffer(
852-
"eps", torch.tensor([eps_init], dtype=torch.float32, device=device)
852+
"eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
853853
)
854854

855855
self.in_keys = [self.ou.key]

0 commit comments

Comments
 (0)