Skip to content

Commit 8d2bc8b

Browse files
author
Vincent Moens
authored
[BugFix] Gracefully handle C++ import error in TorchRL (#1640)
1 parent d93551d commit 8d2bc8b

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

torchrl/_extension.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,10 @@ def _init_extension():
2323
if not is_module_available("torchrl._torchrl"):
2424
warnings.warn("torchrl C++ extension is not available.")
2525
return
26+
27+
28+
EXTENSION_WARNING = (
29+
"Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
30+
"If you installed TorchRL from PyPI, please report the bug on TorchRL github. "
31+
"If you installed TorchRL locally and/or in development mode, check that you have all the required compiling packages."
32+
)

torchrl/data/replay_buffers/samplers.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
5+
import warnings
66
from abc import ABC, abstractmethod
77
from copy import deepcopy
88
from typing import Any, Dict, Tuple, Union
99

1010
import numpy as np
1111
import torch
1212

13-
from torchrl._torchrl import (
14-
MinSegmentTreeFp32,
15-
MinSegmentTreeFp64,
16-
SumSegmentTreeFp32,
17-
SumSegmentTreeFp64,
18-
)
13+
from ..._extension import EXTENSION_WARNING
14+
15+
try:
16+
from torchrl._torchrl import (
17+
MinSegmentTreeFp32,
18+
MinSegmentTreeFp64,
19+
SumSegmentTreeFp32,
20+
SumSegmentTreeFp64,
21+
)
22+
except ImportError:
23+
warnings.warn(EXTENSION_WARNING)
1924

2025
from .storages import Storage
2126
from .utils import _to_numpy, INT_CLASSES

torchrl/modules/distributions/continuous.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
TruncatedNormal as _TruncatedNormal,
1616
)
1717

18-
# from torchrl._torchrl import safeatanh, safetanh
1918
from torchrl.modules.distributions.utils import (
2019
_cast_device,
2120
FasterTransformedDistribution,

0 commit comments

Comments
 (0)