File tree Expand file tree Collapse file tree 3 files changed +35
-2
lines changed Expand file tree Collapse file tree 3 files changed +35
-2
lines changed Original file line number Diff line number Diff line change @@ -81,6 +81,38 @@ def strtobool(val: Any) -> bool:
81
81
82
82
BATCHED_PIPE_TIMEOUT = float (os .environ .get ("BATCHED_PIPE_TIMEOUT" , "10000.0" ))
83
83
84
+ _TORCH_DTYPES = (
85
+ torch .bfloat16 ,
86
+ torch .bool ,
87
+ torch .complex128 ,
88
+ torch .complex32 ,
89
+ torch .complex64 ,
90
+ torch .float16 ,
91
+ torch .float32 ,
92
+ torch .float64 ,
93
+ torch .int16 ,
94
+ torch .int32 ,
95
+ torch .int64 ,
96
+ torch .int8 ,
97
+ torch .qint32 ,
98
+ torch .qint8 ,
99
+ torch .quint4x2 ,
100
+ torch .quint8 ,
101
+ torch .uint8 ,
102
+ )
103
+ if hasattr (torch , "uint16" ):
104
+ _TORCH_DTYPES = _TORCH_DTYPES + (torch .uint16 ,)
105
+ if hasattr (torch , "uint32" ):
106
+ _TORCH_DTYPES = _TORCH_DTYPES + (torch .uint32 ,)
107
+ if hasattr (torch , "uint64" ):
108
+ _TORCH_DTYPES = _TORCH_DTYPES + (torch .uint64 ,)
109
+ _STR_DTYPE_TO_DTYPE = {str (dtype ): dtype for dtype in _TORCH_DTYPES }
110
+ _STRDTYPE2DTYPE = _STR_DTYPE_TO_DTYPE
111
+ _DTYPE_TO_STR_DTYPE = {
112
+ dtype : str_dtype for str_dtype , dtype in _STR_DTYPE_TO_DTYPE .items ()
113
+ }
114
+ _DTYPE2STRDTYPE = _STR_DTYPE_TO_DTYPE
115
+
84
116
85
117
class timeit :
86
118
"""A dirty but easy to use decorator for profiling code."""
Original file line number Diff line number Diff line change 18
18
TensorDict ,
19
19
)
20
20
from tensordict .memmap import MemoryMappedTensor
21
- from tensordict . utils import _STRDTYPE2DTYPE
21
+ from torchrl . _utils import _STRDTYPE2DTYPE
22
22
23
23
from torchrl .data .replay_buffers .utils import (
24
24
_save_pytree ,
Original file line number Diff line number Diff line change 16
16
import numpy as np
17
17
import torch
18
18
from tensordict import is_tensor_collection , MemoryMappedTensor , TensorDictBase
19
- from tensordict .utils import _STRDTYPE2DTYPE , expand_as_right , is_tensorclass
19
+ from tensordict .utils import expand_as_right , is_tensorclass
20
20
from torch import multiprocessing as mp
21
+ from torchrl ._utils import _STRDTYPE2DTYPE
21
22
22
23
try :
23
24
from torch .utils ._pytree import tree_leaves
You can’t perform that action at this time.
0 commit comments