We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
device_safe
1 parent 47dfefa commit fd97614Copy full SHA for fd97614
torchrl/data/replay_buffers/replay_buffers.py
@@ -75,7 +75,10 @@ def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
75
76
77
def _pin_memory(output: Any) -> Any:
78
- if hasattr(output, "pin_memory") and output.device_safe() == torch.device("cpu"):
+ output_device = (
79
+ output.device_safe() if hasattr(output, "device_safe") else output.device
80
+ )
81
+ if hasattr(output, "pin_memory") and output_device == torch.device("cpu"):
82
return output.pin_memory()
83
else:
84
return output
0 commit comments