Commit db7f08d

Author: Vincent Moens

[Refactor] compile compatibility improvements

ghstack-source-id: 95f8241
Pull Request resolved: #2578

Parent: 507766a

15 files changed: +176, -127 lines

test/test_collector.py

Lines changed: 0 additions & 22 deletions
@@ -3172,28 +3172,6 @@ def make_and_test_policy(
     )
 
 
-@pytest.mark.parametrize(
-    "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
-)
-def test_no_stopiteration(ctype):
-    # Tests that there is no StopIteration raised and that the length of the collector is properly set
-    if ctype is SyncDataCollector:
-        envs = SerialEnv(16, CountingEnv)
-    else:
-        envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]
-
-    collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
-    try:
-        c_iter = iter(collector)
-        for i in range(len(collector)):  # noqa: B007
-            c = next(c_iter)
-            assert c is not None
-        assert i == 1
-    finally:
-        collector.shutdown()
-        del collector
-
-
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
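
Note: the deleted test pinned down the collector length via ceiling division: with total_frames=300 and frames_per_batch=173 it expects ceil(300 / 173) = 2 batches, hence the final `assert i == 1`. The arithmetic, as a standalone sketch:

    import math

    total_frames, frames_per_batch = 300, 173
    n_batches = math.ceil(total_frames / frames_per_batch)
    assert n_batches == 2  # so `for i in range(n_batches)` ends with i == 1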

torchrl/collectors/collectors.py

Lines changed: 2 additions & 3 deletions
@@ -147,7 +147,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
     _iterator = None
     total_frames: int
     frames_per_batch: int
-    requested_frames_per_batch: int
     trust_policy: bool
     compiled_policy: bool
     cudagraphed_policy: bool

@@ -306,7 +305,7 @@ def __class_getitem__(self, index):
 
     def __len__(self) -> int:
         if self.total_frames > 0:
-            return -(self.total_frames // -self.requested_frames_per_batch)
+            return -(self.total_frames // -self.frames_per_batch)
         raise RuntimeError("Non-terminating collectors do not have a length")
 
 

@@ -701,7 +700,7 @@ def __init__(
         remainder = total_frames % frames_per_batch
         if remainder != 0 and RL_WARNINGS:
             warnings.warn(
-                f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
+                f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
                 f"This means {frames_per_batch - remainder} additional frames will be collected."
                 "To silence this message, set the environment variable RL_WARNINGS to False."
            )
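
Note: `__len__` uses the negated-floor-division idiom for ceiling division, now over `frames_per_batch` since `requested_frames_per_batch` was dropped from the base class; the warning's extra-frame count follows from the remainder. A worked sketch of both formulas, using the numbers from the deleted test:

    total_frames, frames_per_batch = 300, 173

    # Ceiling division via negated floor division: -(a // -b) == ceil(a / b).
    n_batches = -(total_frames // -frames_per_batch)
    assert n_batches == 2

    # Frames reported by the divisibility warning:
    remainder = total_frames % frames_per_batch  # 127
    additional = frames_per_batch - remainder    # 46 additional frames
    assert n_batches * frames_per_batch == total_frames + additional  # 346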

torchrl/data/tensor_specs.py

Lines changed: 3 additions & 3 deletions
@@ -2312,10 +2312,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded:
         dest_device = torch.device(dest)
         if dest_device == self.device and dest_dtype == self.dtype:
             return self
-        self.space.device = dest_device
+        space = self.space.to(dest_device)
         return Bounded(
-            low=self.space.low,
-            high=self.space.high,
+            low=space.low,
+            high=space.high,
             shape=self.shape,
             device=dest_device,
             dtype=dest_dtype,
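
Note: the old code mutated `self.space.device` in place, so converting a `Bounded` spec to another device also silently moved the source spec's bounds; building a moved copy keeps `to()` side-effect free. A minimal sketch of the non-mutating pattern (a hypothetical Box-like container, not torchrl's actual class):

    import torch

    class Box:
        # Hypothetical stand-in for a low/high bounds container.
        def __init__(self, low: torch.Tensor, high: torch.Tensor):
            self.low, self.high = low, high

        def to(self, device: torch.device) -> "Box":
            # Return a new container; the original keeps its tensors untouched.
            return Box(self.low.to(device), self.high.to(device))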

torchrl/envs/batched_envs.py

Lines changed: 7 additions & 4 deletions
@@ -1356,12 +1356,15 @@ def _start_workers(self) -> None:
 
         from torchrl.envs.env_creator import EnvCreator
 
+        num_threads = max(
+            1, torch.get_num_threads() - self.num_workers
+        )  # 1 more thread for this proc
+
         if self.num_threads is None:
-            self.num_threads = max(
-                1, torch.get_num_threads() - self.num_workers
-            )  # 1 more thread for this proc
+            self.num_threads = num_threads
 
-        torch.set_num_threads(self.num_threads)
+        if self.num_threads != torch.get_num_threads():
+            torch.set_num_threads(self.num_threads)
 
         if self._mp_start_method is not None:
             ctx = mp.get_context(self._mp_start_method)
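
Note: hoisting the computation and guarding the call means `torch.set_num_threads`, a process-global setting, is only touched when the value would actually change. The guard in isolation:

    import torch

    num_workers = 4
    # Reserve the workers' threads, keeping at least one for this process.
    num_threads = max(1, torch.get_num_threads() - num_workers)
    # Only mutate the global thread pool when the value would change.
    if num_threads != torch.get_num_threads():
        torch.set_num_threads(num_threads)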

torchrl/modules/distributions/continuous.py

Lines changed: 14 additions & 7 deletions
@@ -397,7 +397,6 @@ def __init__(
         event_dims: int | None = None,
         tanh_loc: bool = False,
         safe_tanh: bool = True,
-        **kwargs,
     ):
         if not isinstance(loc, torch.Tensor):
             loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())

@@ -683,6 +682,7 @@ def __init__(
         event_dims: int = 1,
         atol: float = 1e-6,
         rtol: float = 1e-6,
+        safe: bool = True,
     ):
         minmax_msg = "high value has been found to be equal or less than low value"
         if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):

@@ -695,12 +695,19 @@ def __init__(
             if not all(high > low):
                 raise ValueError(minmax_msg)
 
-        t = SafeTanhTransform()
-        non_trivial_min = (isinstance(low, torch.Tensor) and (low != -1.0).any()) or (
-            not isinstance(low, torch.Tensor) and low != -1.0
+        if safe:
+            if is_dynamo_compiling():
+                _err_compile_safetanh()
+            t = SafeTanhTransform()
+        else:
+            t = torch.distributions.TanhTransform()
+        non_trivial_min = is_dynamo_compiling() or (
+            (isinstance(low, torch.Tensor) and (low != -1.0).any())
+            or (not isinstance(low, torch.Tensor) and low != -1.0)
         )
-        non_trivial_max = (isinstance(high, torch.Tensor) and (high != 1.0).any()) or (
-            not isinstance(high, torch.Tensor) and high != 1.0
+        non_trivial_max = is_dynamo_compiling() or (
+            (isinstance(high, torch.Tensor) and (high != 1.0).any())
+            or (not isinstance(high, torch.Tensor) and high != 1.0)
         )
         self.non_trivial = non_trivial_min or non_trivial_max
 

@@ -778,7 +785,7 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
 def _err_compile_safetanh():
     raise RuntimeError(
         "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. "
-        "To deactivate it, pass safe_tanh=False. "
+        " To deactivate it, pass safe_tanh=False. "
         "If you are using a ProbabilisticTensorDictModule, this can be done via "
         "`distribution_kwargs={'safe_tanh': False}`. "
         "See https://github.com/pytorch/pytorch/issues/133529 for more details."

torchrl/modules/distributions/utils.py

Lines changed: 10 additions & 3 deletions
@@ -9,6 +9,11 @@
 from torch import autograd, distributions as d
 from torch.distributions import Independent, Transform, TransformedDistribution
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 
 def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]:
     if isinstance(elt, torch.Tensor):

@@ -40,10 +45,12 @@ class FasterTransformedDistribution(TransformedDistribution):
     __doc__ = __doc__ + TransformedDistribution.__doc__
 
     def __init__(self, base_distribution, transforms, validate_args=None):
+        if is_dynamo_compiling():
+            return super().__init__(
+                base_distribution, transforms, validate_args=validate_args
+            )
         if isinstance(transforms, Transform):
-            self.transforms = [
-                transforms,
-            ]
+            self.transforms = [transforms]
         elif isinstance(transforms, list):
             raise ValueError("Make a ComposeTransform first.")
         else:
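
Note: the try/except keeps the import working across torch versions, since `torch.compiler.is_dynamo_compiling` is the public entry point on recent releases while older ones only expose the private dynamo helper. Other modules can reuse the same guard, e.g. this sketch:

    try:
        from torch.compiler import is_dynamo_compiling
    except ImportError:  # older torch: fall back to the private dynamo API
        from torch._dynamo import is_compiling as is_dynamo_compiling

    def use_python_fast_path() -> bool:
        # Skip hand-written shortcuts while dynamo is tracing, where the
        # stock implementation compiles more reliably.
        return not is_dynamo_compiling()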

torchrl/modules/models/decision_transformer.py

Lines changed: 5 additions & 0 deletions
@@ -90,7 +90,12 @@ def __init__(
         state_dim,
         action_dim,
         config: dict | DTConfig = None,
+        device: torch.device | None = None,
     ):
+        if device is not None:
+            with torch.device(device):
+                return self.__init__(state_dim, action_dim, config)
+
         if not _has_transformers:
             raise ImportError(
                 "transformers is not installed. Please install it with `pip install transformers`."

torchrl/modules/tensordict_module/actors.py

Lines changed: 4 additions & 0 deletions
@@ -1783,6 +1783,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper):
             For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries
             of the context will be masked. Defaults to 5.
         spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module.
+        device (torch.device, optional): if provided, the device where the buffers / specs will be placed.
 
     Examples:
         >>> import torch

@@ -1836,6 +1837,7 @@ def __init__(
         *,
         inference_context: int = 5,
         spec: Optional[TensorSpec] = None,
+        device: torch.device | None = None,
     ):
         super().__init__(policy)
         self.observation_key = "observation"

@@ -1857,6 +1859,8 @@ def __init__(
             self._spec[self.action_key] = None
         else:
             self._spec = Composite({key: None for key in policy.out_keys})
+        if device is not None:
+            self._spec = self._spec.to(device)
         self.checked = False
 
     @property
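
Note: moving the spec once at construction avoids per-call device transfers; `Composite.to` returns the spec on the target device. A small sketch, assuming the renamed spec classes (`Composite`, `Unbounded`) are available and with placeholder contents:

    import torch
    from torchrl.data import Composite, Unbounded

    spec = Composite(action=Unbounded(shape=(7,)))
    spec = spec.to(torch.device("cpu"))  # spec now lives on the target device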

torchrl/modules/tensordict_module/common.py

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
         keys = [out_key]
         values = [spec]
     else:
-        keys = list(spec.keys(True, True))
+        # Make dynamo happy with the list creation
+        keys = [key for key in spec.keys(True, True)]  # noqa: C416
         values = [spec[key] for key in keys]
     for _spec, _key in zip(values, keys):
         if _spec is None:
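
Note: `spec.keys(True, True)` returns a custom keys view; calling `list()` on it can graph-break under dynamo, whereas an explicit comprehension traces as plain iteration. A generic sketch of the distinction (hypothetical view class, not torchrl code):

    class KeysView:
        # Hypothetical stand-in for a spec's custom keys view.
        def __init__(self, keys):
            self._keys = keys

        def __iter__(self):
            yield from self._keys

    view = KeysView(["action", ("agents", "action")])
    keys = [key for key in view]  # noqa: C416 -- comprehension over list(view)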
