Commit 14b2775

Author: Vincent Moens (committed)
[Refactor] Deprecate recurrent_mode API to use decorators/CMs instead
ghstack-source-id: 80f705e
Pull Request resolved: #2584
1 parent 0f59226
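
The diffs below all make one migration: the deprecated .set_recurrent_mode(True) setter is replaced by a default_recurrent_mode constructor argument and a set_recurrent_mode context manager. A minimal usage sketch, reusing the GRUModule arguments from the tests below:

    from torchrl.modules import GRUModule, set_recurrent_mode

    # Old (deprecated): GRUModule(...).set_recurrent_mode(True)
    # New: set the default at construction time...
    gru = GRUModule(
        input_size=3,
        hidden_size=12,
        batch_first=True,
        in_keys=["observation", "hidden"],
        out_keys=["intermediate", ("next", "hidden")],
        default_recurrent_mode=True,
    )

    # ...and scope temporary overrides with the context manager.
    with set_recurrent_mode(False):
        assert not gru.recurrent_mode  # step-by-step ("sequential") inside the block
    assert gru.recurrent_mode          # constructor default restored on exit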

12 files changed: +230 -33 lines


docs/source/reference/modules.rst

Lines changed: 2 additions & 0 deletions
@@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer.
     OnlineDTActor
     RSSMPosterior
     RSSMPrior
+    set_recurrent_mode
+    recurrent_mode
 
 Multi-agent-specific modules
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

test/test_cost.py

Lines changed: 24 additions & 0 deletions
@@ -47,6 +47,7 @@
     DistributionalQValueActor,
     OneHotCategorical,
     QValueActor,
+    recurrent_mode,
     SafeSequential,
     WorldModelWrapper,
 )
@@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs):
 
 
 class TestBase:
+    def test_decorators(self):
+        class MyLoss(LossModule):
+            def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+                assert recurrent_mode()
+                assert exploration_type() is ExplorationType.DETERMINISTIC
+                return TensorDict()
+
+            def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
+                assert recurrent_mode()
+                assert exploration_type() is ExplorationType.DETERMINISTIC
+                return TensorDict()
+
+            def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
+                assert recurrent_mode()
+                assert exploration_type() is ExplorationType.DETERMINISTIC
+                return TensorDict()
+
+        loss = MyLoss()
+        loss.forward(None)
+        loss.actor_loss(None)
+        loss.something_loss(None)
+        assert not recurrent_mode()
+
     @pytest.mark.parametrize("expand_dim", [None, 2])
     @pytest.mark.parametrize("compare_against", [True, False])
     @pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion")
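
The test above pins down the decorator behavior: every LossModule entry point (forward as well as the *_loss methods) runs with recurrent mode on and deterministic exploration, and the flag is cleared again once the call returns. A minimal sketch of that wrap-and-restore pattern, with made-up names (_RECURRENT_MODE, recurrent_mode_flag, with_recurrent_mode); torchrl's real implementation lives elsewhere in this PR:

    import functools

    _RECURRENT_MODE = False  # hypothetical module-level flag

    def recurrent_mode_flag() -> bool:
        return _RECURRENT_MODE

    def with_recurrent_mode(fn):
        # Raise the flag for the duration of fn, restoring the previous
        # value on exit: that is why the test can assert
        # `not recurrent_mode()` after the calls return.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            global _RECURRENT_MODE
            prev, _RECURRENT_MODE = _RECURRENT_MODE, True
            try:
                return fn(*args, **kwargs)
            finally:
                _RECURRENT_MODE = prev
        return wrapper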

test/test_tensordictmodules.py

Lines changed: 53 additions & 1 deletion
@@ -36,6 +36,7 @@
     OnlineDTActor,
     ProbabilisticActor,
     SafeModule,
+    set_recurrent_mode,
     TanhDelta,
     TanhNormal,
     ValueOperator,
@@ -729,6 +730,31 @@ def test_errs(self):
         with pytest.raises(KeyError, match="is_init"):
             lstm_module(td)
 
+    @pytest.mark.parametrize("default_val", [False, True, None])
+    def test_set_recurrent_mode(self, default_val):
+        lstm_module = LSTMModule(
+            input_size=3,
+            hidden_size=12,
+            batch_first=True,
+            in_keys=["observation", "hidden0", "hidden1"],
+            out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
+            default_recurrent_mode=default_val,
+        )
+        assert lstm_module.recurrent_mode is bool(default_val)
+        with set_recurrent_mode(True):
+            assert lstm_module.recurrent_mode
+            with set_recurrent_mode(False):
+                assert not lstm_module.recurrent_mode
+                with set_recurrent_mode("recurrent"):
+                    assert lstm_module.recurrent_mode
+                    with set_recurrent_mode("sequential"):
+                        assert not lstm_module.recurrent_mode
+                    assert lstm_module.recurrent_mode
+                assert not lstm_module.recurrent_mode
+            assert lstm_module.recurrent_mode
+        assert lstm_module.recurrent_mode is bool(default_val)
+
+    @pytest.mark.filterwarnings("ignore::DeprecationWarning")
     def test_set_temporal_mode(self):
         lstm_module = LSTMModule(
             input_size=3,
@@ -754,7 +780,8 @@ def test_python_cudnn(self):
             num_layers=2,
             in_keys=["observation", "hidden0", "hidden1"],
             out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
-        ).set_recurrent_mode(True)
+            default_recurrent_mode=True,
+        )
         obs = torch.rand(10, 20, 3)
 
         hidden0 = torch.rand(10, 20, 2, 12)
@@ -1109,6 +1136,31 @@ def test_errs(self):
         with pytest.raises(KeyError, match="is_init"):
             gru_module(td)
 
+    @pytest.mark.parametrize("default_val", [False, True, None])
+    def test_set_recurrent_mode(self, default_val):
+        gru_module = GRUModule(
+            input_size=3,
+            hidden_size=12,
+            batch_first=True,
+            in_keys=["observation", "hidden"],
+            out_keys=["intermediate", ("next", "hidden")],
+            default_recurrent_mode=default_val,
+        )
+        assert gru_module.recurrent_mode is bool(default_val)
+        with set_recurrent_mode(True):
+            assert gru_module.recurrent_mode
+            with set_recurrent_mode(False):
+                assert not gru_module.recurrent_mode
+                with set_recurrent_mode("recurrent"):
+                    assert gru_module.recurrent_mode
+                    with set_recurrent_mode("sequential"):
+                        assert not gru_module.recurrent_mode
+                    assert gru_module.recurrent_mode
+                assert not gru_module.recurrent_mode
+            assert gru_module.recurrent_mode
+        assert gru_module.recurrent_mode is bool(default_val)
+
+    @pytest.mark.filterwarnings("ignore::DeprecationWarning")
     def test_set_temporal_mode(self):
         gru_module = GRUModule(
             input_size=3,
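
The nested with-blocks above fix the context manager's semantics: it accepts booleans as well as the strings "recurrent" and "sequential", the innermost active block wins, and each block restores on exit whatever mode was in force before it, down to the module's default_recurrent_mode (None is read as False). Schematically:

    with set_recurrent_mode("recurrent"):       # whole-sequence processing
        with set_recurrent_mode("sequential"):  # innermost block wins: step-by-step
            ...
        # back to "recurrent" here
    # back to the constructor default here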

test/test_transforms.py

Lines changed: 4 additions & 2 deletions
@@ -10885,7 +10885,8 @@ def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"):
             in_keys=["observation", "rhs", "is_init"],
             out_keys=["output", ("next", "rhs")],
             device=device,
-        ).set_recurrent_mode(True)
+            default_recurrent_mode=True,
+        )
 
     def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
         return LSTMModule(
@@ -10895,7 +10896,8 @@ def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
             in_keys=["observation", "rhs_h", "rhs_c", "is_init"],
             out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")],
             device=device,
-        ).set_recurrent_mode(True)
+            default_recurrent_mode=True,
+        )
 
     def _make_batch(self, batch_size: int = 2, sequence_length: int = 5):
         observation = torch.randn(batch_size, sequence_length + 1, 4)

torchrl/_utils.py

Lines changed: 23 additions & 0 deletions
@@ -15,9 +15,11 @@
 import os
 import pickle
 import sys
+import threading
 import time
 import traceback
 import warnings
+from contextlib import nullcontext
 from copy import copy
 from distutils.util import strtobool
 from functools import wraps
@@ -32,6 +34,11 @@
 from tensordict.utils import NestedKey
 from torch import multiprocessing as mp
 
+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
 LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
 logger = logging.getLogger("torchrl")
 logger.setLevel(getattr(logging, LOGGING_LEVEL))
@@ -827,3 +834,19 @@ def _make_ordinal_device(device: torch.device):
     if device.type == "mps" and device.index is None:
         return torch.device("mps", index=0)
     return device
+
+
+class _ContextManager:
+    def __init__(self):
+        self._mode: Any | None = None
+        self._lock = threading.Lock()
+
+    def get_mode(self) -> Any | None:
+        cm = self._lock if not is_compiling() else nullcontext()
+        with cm:
+            return self._mode
+
+    def set_mode(self, type: Any | None) -> None:
+        cm = self._lock if not is_compiling() else nullcontext()
+        with cm:
+            self._mode = type
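
_ContextManager is a small thread-safe holder for a global mode value; the lock is skipped when running under torch.compile (is_compiling() is True) so the accessors stay traceable. A sketch of how a user-facing context manager could be layered on such a holder; set_mode_sketch and _mode_holder are hypothetical names, not this PR's actual wiring:

    from contextlib import contextmanager

    _mode_holder = _ContextManager()  # one global holder per setting

    @contextmanager
    def set_mode_sketch(mode):
        prev = _mode_holder.get_mode()
        _mode_holder.set_mode(mode)
        try:
            yield
        finally:
            _mode_holder.set_mode(prev)  # unwind so nested blocks restore correctly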

torchrl/envs/transforms/transforms.py

Lines changed: 2 additions & 1 deletion
@@ -7411,7 +7411,8 @@ class BurnInTransform(Transform):
     ...     hidden_size=10,
     ...     in_keys=["observation", "hidden"],
     ...     out_keys=["intermediate", ("next", "hidden")],
-    ... ).set_recurrent_mode(True)
+    ...     default_recurrent_mode=True,
+    ... )
     >>> burn_in_transform = BurnInTransform(
     ...     modules=[gru_module],
     ...     burn_in=5,

torchrl/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -80,10 +80,12 @@
     QValueActor,
     QValueHook,
     QValueModule,
+    recurrent_mode,
     SafeModule,
     SafeProbabilisticModule,
     SafeProbabilisticTensorDictSequential,
     SafeSequential,
+    set_recurrent_mode,
     TanhModule,
     ValueOperator,
     VmapModule,

torchrl/modules/tensordict_module/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -34,6 +34,15 @@
     SafeProbabilisticModule,
     SafeProbabilisticTensorDictSequential,
 )
-from .rnn import GRU, GRUCell, GRUModule, LSTM, LSTMCell, LSTMModule
+from .rnn import (
+    GRU,
+    GRUCell,
+    GRUModule,
+    LSTM,
+    LSTMCell,
+    LSTMModule,
+    recurrent_mode,
+    set_recurrent_mode,
+)
 from .sequence import SafeSequential
 from .world_models import WorldModelWrapper
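
With these re-exports, both helpers become part of the public namespace:

    from torchrl.modules import recurrent_mode, set_recurrent_mode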
