Skip to content

Commit 3dbd84c

Browse files
author
Vincent Moens
committed
[Quality] Local dtype maps
ghstack-source-id: 26eb314 Pull-Request-resolved: #2936
1 parent 6ef7f64 commit 3dbd84c

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

torchrl/_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,38 @@ def strtobool(val: Any) -> bool:
8181

8282
BATCHED_PIPE_TIMEOUT = float(os.environ.get("BATCHED_PIPE_TIMEOUT", "10000.0"))
8383

84+
_TORCH_DTYPES = (
85+
torch.bfloat16,
86+
torch.bool,
87+
torch.complex128,
88+
torch.complex32,
89+
torch.complex64,
90+
torch.float16,
91+
torch.float32,
92+
torch.float64,
93+
torch.int16,
94+
torch.int32,
95+
torch.int64,
96+
torch.int8,
97+
torch.qint32,
98+
torch.qint8,
99+
torch.quint4x2,
100+
torch.quint8,
101+
torch.uint8,
102+
)
103+
if hasattr(torch, "uint16"):
104+
_TORCH_DTYPES = _TORCH_DTYPES + (torch.uint16,)
105+
if hasattr(torch, "uint32"):
106+
_TORCH_DTYPES = _TORCH_DTYPES + (torch.uint32,)
107+
if hasattr(torch, "uint64"):
108+
_TORCH_DTYPES = _TORCH_DTYPES + (torch.uint64,)
109+
_STR_DTYPE_TO_DTYPE = {str(dtype): dtype for dtype in _TORCH_DTYPES}
110+
_STRDTYPE2DTYPE = _STR_DTYPE_TO_DTYPE
111+
_DTYPE_TO_STR_DTYPE = {
112+
dtype: str_dtype for str_dtype, dtype in _STR_DTYPE_TO_DTYPE.items()
113+
}
114+
_DTYPE2STRDTYPE = _STR_DTYPE_TO_DTYPE
115+
84116

85117
class timeit:
86118
"""A dirty but easy to use decorator for profiling code."""

torchrl/data/replay_buffers/checkpointers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TensorDict,
1919
)
2020
from tensordict.memmap import MemoryMappedTensor
21-
from tensordict.utils import _STRDTYPE2DTYPE
21+
from torchrl._utils import _STRDTYPE2DTYPE
2222

2323
from torchrl.data.replay_buffers.utils import (
2424
_save_pytree,

torchrl/data/replay_buffers/writers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
import numpy as np
1717
import torch
1818
from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase
19-
from tensordict.utils import _STRDTYPE2DTYPE, expand_as_right, is_tensorclass
19+
from tensordict.utils import expand_as_right, is_tensorclass
2020
from torch import multiprocessing as mp
21+
from torchrl._utils import _STRDTYPE2DTYPE
2122

2223
try:
2324
from torch.utils._pytree import tree_leaves

0 commit comments

Comments
 (0)