Skip to content

Commit 49039d1

Browse files
svarolgunesSerhat Varolgünes
andauthored
[Feature] Make DQN compatible with nn.Module (#632)
* Modification to let nn.Module type value_network arg in DQNLoss * Unit tests added to cover the feature * utility functions added * tests are updated, util fnc rectified * initializer docstrings are updated for loss classes * util functions added to rst file * lint fix Co-authored-by: Serhat Varolgünes <svarolgunes@fb.com>
1 parent e96fd37 commit 49039d1

File tree

5 files changed

+235
-19
lines changed

5 files changed

+235
-19
lines changed

docs/source/reference/modules.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ TensorDict modules
2323
ActorValueOperator
2424
ActorCriticOperator
2525
ActorCriticWrapper
26+
is_tensordict_compatible
27+
ensure_tensordict_compatible
2628

2729
Hooks
2830
-----

test/test_cost.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ class TestDQN:
126126
seed = 0
127127

128128
def _create_mock_actor(
129-
self, action_spec_type, batch=2, obs_dim=3, action_dim=4, device="cpu"
129+
self,
130+
action_spec_type,
131+
batch=2,
132+
obs_dim=3,
133+
action_dim=4,
134+
device="cpu",
135+
is_nn_module=False,
130136
):
131137
# Actor
132138
if action_spec_type == "one_hot":
@@ -141,6 +147,8 @@ def _create_mock_actor(
141147
raise ValueError(f"Wrong {action_spec_type}")
142148

143149
module = nn.Linear(obs_dim, action_dim)
150+
if is_nn_module:
151+
return module.to(device)
144152
actor = QValueActor(
145153
spec=CompositeSpec(
146154
action=action_spec, action_value=None, chosen_action_value=None
@@ -158,6 +166,7 @@ def _create_mock_distributional_actor(
158166
atoms=5,
159167
vmin=1,
160168
vmax=5,
169+
is_nn_module=False,
161170
):
162171
# Actor
163172
if action_spec_type == "mult_one_hot":
@@ -170,6 +179,11 @@ def _create_mock_distributional_actor(
170179
raise ValueError(f"Wrong {action_spec_type}")
171180
support = torch.linspace(vmin, vmax, atoms, dtype=torch.float)
172181
module = MLP(obs_dim, (atoms, action_dim))
182+
# TODO: Fails tests with
183+
# TypeError: __init__() missing 1 required keyword-only argument: 'support'
184+
# DistributionalQValueActor initializer expects additional inputs.
185+
# if is_nn_module:
186+
# return module
173187
actor = DistributionalQValueActor(
174188
spec=CompositeSpec(action=action_spec, action_value=None),
175189
module=module,
@@ -272,10 +286,11 @@ def _create_seq_mock_data_dqn(
272286
@pytest.mark.parametrize(
273287
"action_spec_type", ("nd_bounded", "one_hot", "categorical")
274288
)
275-
def test_dqn(self, delay_value, device, action_spec_type):
289+
@pytest.mark.parametrize("is_nn_module", (False, True))
290+
def test_dqn(self, delay_value, device, action_spec_type, is_nn_module):
276291
torch.manual_seed(self.seed)
277292
actor = self._create_mock_actor(
278-
action_spec_type=action_spec_type, device=device
293+
action_spec_type=action_spec_type, device=device, is_nn_module=is_nn_module
279294
)
280295
td = self._create_mock_data_dqn(
281296
action_spec_type=action_spec_type, device=device
@@ -471,12 +486,13 @@ def test_dqn_batcher_nofunctorch(
471486
@pytest.mark.parametrize(
472487
"action_spec_type", ("mult_one_hot", "one_hot", "categorical")
473488
)
489+
@pytest.mark.parametrize("is_nn_module", (False, True))
474490
def test_distributional_dqn(
475-
self, atoms, delay_value, device, action_spec_type, gamma=0.9
491+
self, atoms, delay_value, device, action_spec_type, is_nn_module, gamma=0.9
476492
):
477493
torch.manual_seed(self.seed)
478494
actor = self._create_mock_distributional_actor(
479-
action_spec_type=action_spec_type, atoms=atoms
495+
action_spec_type=action_spec_type, atoms=atoms, is_nn_module=is_nn_module
480496
).to(device)
481497

482498
td = self._create_mock_data_dqn(

test/test_tensordictmodules.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
)
3232
from torchrl.envs.utils import set_exploration_mode
3333
from torchrl.modules import NormalParamWrapper, TanhNormal, TensorDictModule
34+
from torchrl.modules.tensordict_module.common import (
35+
is_tensordict_compatible,
36+
ensure_tensordict_compatible,
37+
)
3438
from torchrl.modules.tensordict_module.probabilistic import (
3539
ProbabilisticTensorDictModule,
3640
)
@@ -1842,3 +1846,110 @@ def test_subsequence_weight_update(self):
18421846
if __name__ == "__main__":
18431847
args, unknown = argparse.ArgumentParser().parse_known_args()
18441848
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1849+
1850+
1851+
def test_is_tensordict_compatible():
1852+
class MultiHeadLinear(nn.Module):
1853+
def __init__(self, in_1, out_1, out_2, out_3):
1854+
super().__init__()
1855+
self.linear_1 = nn.Linear(in_1, out_1)
1856+
self.linear_2 = nn.Linear(in_1, out_2)
1857+
self.linear_3 = nn.Linear(in_1, out_3)
1858+
1859+
def forward(self, x):
1860+
return self.linear_1(x), self.linear_2(x), self.linear_3(x)
1861+
1862+
td_module = TensorDictModule(
1863+
MultiHeadLinear(5, 4, 3, 2),
1864+
in_keys=["in_1", "in_2"],
1865+
out_keys=["out_1", "out_2"],
1866+
)
1867+
assert is_tensordict_compatible(td_module)
1868+
1869+
class MockCompatibleModule(nn.Module):
1870+
def __init__(self, in_keys, out_keys):
1871+
self.in_keys = in_keys
1872+
self.out_keys = out_keys
1873+
1874+
def forward(self, tensordict):
1875+
pass
1876+
1877+
compatible_nn_module = MockCompatibleModule(
1878+
in_keys=["in_1", "in_2"],
1879+
out_keys=["out_1", "out_2"],
1880+
)
1881+
assert is_tensordict_compatible(compatible_nn_module)
1882+
1883+
class MockIncompatibleModuleNoKeys(nn.Module):
1884+
def forward(self, input):
1885+
pass
1886+
1887+
incompatible_nn_module_no_keys = MockIncompatibleModuleNoKeys()
1888+
assert not is_tensordict_compatible(incompatible_nn_module_no_keys)
1889+
1890+
class MockIncompatibleModuleMultipleArgs(nn.Module):
1891+
def __init__(self, in_keys, out_keys):
1892+
self.in_keys = in_keys
1893+
self.out_keys = out_keys
1894+
1895+
def forward(self, input_1, input_2):
1896+
pass
1897+
1898+
incompatible_nn_module_multi_args = MockIncompatibleModuleMultipleArgs(
1899+
in_keys=["in_1", "in_2"],
1900+
out_keys=["out_1", "out_2"],
1901+
)
1902+
with pytest.raises(TypeError):
1903+
is_tensordict_compatible(incompatible_nn_module_multi_args)
1904+
1905+
1906+
def test_ensure_tensordict_compatible():
1907+
class MultiHeadLinear(nn.Module):
1908+
def __init__(self, in_1, out_1, out_2, out_3):
1909+
super().__init__()
1910+
self.linear_1 = nn.Linear(in_1, out_1)
1911+
self.linear_2 = nn.Linear(in_1, out_2)
1912+
self.linear_3 = nn.Linear(in_1, out_3)
1913+
1914+
def forward(self, x):
1915+
return self.linear_1(x), self.linear_2(x), self.linear_3(x)
1916+
1917+
td_module = TensorDictModule(
1918+
MultiHeadLinear(5, 4, 3, 2),
1919+
in_keys=["in_1", "in_2"],
1920+
out_keys=["out_1", "out_2"],
1921+
)
1922+
ensured_module = ensure_tensordict_compatible(td_module)
1923+
assert ensured_module is td_module
1924+
with pytest.raises(TypeError):
1925+
ensure_tensordict_compatible(td_module, in_keys=["input"])
1926+
with pytest.raises(TypeError):
1927+
ensure_tensordict_compatible(td_module, out_keys=["output"])
1928+
1929+
class NonNNModule:
1930+
def __init__(self):
1931+
pass
1932+
1933+
def forward(self, x):
1934+
pass
1935+
1936+
non_nn_module = NonNNModule()
1937+
with pytest.raises(TypeError):
1938+
ensure_tensordict_compatible(non_nn_module)
1939+
1940+
class ErrorNNModule(nn.Module):
1941+
def forward(self, in_1, in_2):
1942+
pass
1943+
1944+
error_nn_module = ErrorNNModule()
1945+
with pytest.raises(TypeError):
1946+
ensure_tensordict_compatible(error_nn_module, in_keys=["input"])
1947+
1948+
nn_module = MultiHeadLinear(5, 4, 3, 2)
1949+
ensured_module = ensure_tensordict_compatible(
1950+
nn_module,
1951+
in_keys=["x"],
1952+
out_keys=["out_1", "out_2", "out_3"],
1953+
)
1954+
assert set(ensured_module.in_keys) == {"x"}
1955+
assert isinstance(ensured_module, TensorDictModule)

torchrl/modules/tensordict_module/common.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import inspect
89
import warnings
910
from copy import deepcopy
1011
from textwrap import indent
@@ -14,6 +15,7 @@
1415
List,
1516
Optional,
1617
Sequence,
18+
Type,
1719
Union,
1820
)
1921

@@ -625,3 +627,83 @@ def __getattr__(self, name: str) -> Any:
625627

626628
def forward(self, *args, **kwargs):
627629
return self.td_module.forward(*args, **kwargs)
630+
631+
632+
def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]):
633+
sig = inspect.signature(module.forward)
634+
635+
if isinstance(module, TensorDictModule) or (
636+
len(sig.parameters) == 1
637+
and hasattr(module, "in_keys")
638+
and hasattr(module, "out_keys")
639+
):
640+
# if the module is a TensorDictModule or takes a single argument and defines
641+
# in_keys and out_keys then we assume it can already deal with TensorDict input
642+
# to forward and we return True
643+
return True
644+
elif not hasattr(module, "in_keys") and not hasattr(module, "out_keys"):
645+
# if it's not a TensorDictModule, and in_keys and out_keys are not defined then
646+
# we assume no TensorDict compatibility and will try to wrap it.
647+
return False
648+
649+
# if in_keys or out_keys were defined but module is not a TensorDictModule or
650+
# accepts multiple arguments then it's likely the user is trying to do something
651+
# that will have undetermined behaviour, we raise an error
652+
raise TypeError(
653+
"Received a module that defines in_keys or out_keys and also expects multiple "
654+
"arguments to module.forward. If the module is compatible with TensorDict, it "
655+
"should take a single argument of type TensorDict to module.forward and define "
656+
"both in_keys and out_keys. Alternatively, module.forward can accept "
657+
"arbitrarily many tensor inputs and leave in_keys and out_keys undefined and "
658+
"TorchRL will attempt to automatically wrap the module with a TensorDictModule."
659+
)
660+
661+
662+
def ensure_tensordict_compatible(
663+
module: Union[
664+
FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module
665+
],
666+
in_keys: Optional[Iterable[str]] = None,
667+
out_keys: Optional[Iterable[str]] = None,
668+
safe: bool = False,
669+
wrapper_type: Optional[Type] = TensorDictModule,
670+
):
671+
"""Checks and ensures an object with forward method is TensorDict compatible."""
672+
if is_tensordict_compatible(module):
673+
if in_keys is not None and set(in_keys) != set(module.in_keys):
674+
raise TypeError(
675+
f"Arguments to module.forward ({set(module.in_keys)}) doesn't match "
676+
f"with the expected TensorDict in_keys ({set(in_keys)})."
677+
)
678+
if out_keys is not None and set(module.out_keys) != set(out_keys):
679+
raise TypeError(
680+
f"Outputs of module.forward ({set(module.out_keys)}) doesn't match "
681+
f"with the expected TensorDict out_keys ({set(out_keys)})."
682+
)
683+
# return module itself if it's already tensordict compatible
684+
return module
685+
686+
if not isinstance(module, nn.Module):
687+
raise TypeError(
688+
"Argument to ensure_tensordict_compatible should be either "
689+
"a TensorDictModule or an nn.Module"
690+
)
691+
692+
sig = inspect.signature(module.forward)
693+
if in_keys is not None and set(sig.parameters) != set(in_keys):
694+
raise TypeError(
695+
"Arguments to module.forward are incompatible with entries in "
696+
"env.observation_spec. If you want TorchRL to automatically "
697+
"wrap your module with a TensorDictModule then the arguments "
698+
"to module must correspond one-to-one with entries in "
699+
"in_keys. For more complex behaviour and more control you can "
700+
"consider writing your own TensorDictModule."
701+
)
702+
703+
# TODO: Check whether out_keys match (at least in number) if they are provided.
704+
kwargs = {}
705+
if in_keys is not None:
706+
kwargs["in_keys"] = in_keys
707+
if out_keys is not None:
708+
kwargs["out_keys"] = out_keys
709+
return wrapper_type(module, **kwargs)

0 commit comments

Comments
 (0)