Skip to content

Commit fd97614

Browse files
authored
[BugFix] Wrong call to device_safe in replay buffer code (#454)
1 parent 47dfefa commit fd97614

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
7575

7676

7777
def _pin_memory(output: Any) -> Any:
78-
if hasattr(output, "pin_memory") and output.device_safe() == torch.device("cpu"):
78+
output_device = (
79+
output.device_safe() if hasattr(output, "device_safe") else output.device
80+
)
81+
if hasattr(output, "pin_memory") and output_device == torch.device("cpu"):
7982
return output.pin_memory()
8083
else:
8184
return output

0 commit comments

Comments
 (0)