Skip to content

Commit 244f93a

Browse files
author
Vincent Moens
authored
[Feature] Support for GRU (#1586)
1 parent f62785b commit 244f93a

File tree

5 files changed

+682
-24
lines changed

5 files changed

+682
-24
lines changed

docs/source/reference/modules.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ algorithms, such as DQN, DDPG or Dreamer.
332332
DistributionalDQNnet
333333
DreamerActor
334334
DuelingCnnDQNet
335+
GRUModule
335336
LSTMModule
336337
ObsDecoder
337338
ObsEncoder

test/test_tensordictmodules.py

Lines changed: 260 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
AdditiveGaussianWrapper,
2727
DecisionTransformerInferenceWrapper,
2828
DTActor,
29+
GRUModule,
2930
LSTMModule,
3031
MLP,
3132
NormalParamWrapper,
@@ -1645,9 +1646,9 @@ def test_set_temporal_mode(self):
16451646
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
16461647
)
16471648
assert lstm_module.set_recurrent_mode(False) is lstm_module
1648-
assert not lstm_module.set_recurrent_mode(False).temporal_mode
1649+
assert not lstm_module.set_recurrent_mode(False).recurrent_mode
16491650
assert lstm_module.set_recurrent_mode(True) is not lstm_module
1650-
assert lstm_module.set_recurrent_mode(True).temporal_mode
1651+
assert lstm_module.set_recurrent_mode(True).recurrent_mode
16511652
assert set(lstm_module.set_recurrent_mode(True).parameters()) == set(
16521653
lstm_module.parameters()
16531654
)
@@ -1822,6 +1823,263 @@ def create_transformed_env():
18221823
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
18231824

18241825

1826+
class TestGRUModule:
1827+
def test_errs(self):
1828+
with pytest.raises(ValueError, match="batch_first"):
1829+
gru_module = GRUModule(
1830+
input_size=3,
1831+
hidden_size=12,
1832+
batch_first=False,
1833+
in_keys=["observation", "hidden"],
1834+
out_keys=["intermediate", ("next", "hidden")],
1835+
)
1836+
with pytest.raises(ValueError, match="in_keys"):
1837+
gru_module = GRUModule(
1838+
input_size=3,
1839+
hidden_size=12,
1840+
batch_first=True,
1841+
in_keys=[
1842+
"observation",
1843+
"hidden0",
1844+
"hidden1",
1845+
],
1846+
out_keys=["intermediate", ("next", "hidden")],
1847+
)
1848+
with pytest.raises(TypeError, match="incompatible function arguments"):
1849+
gru_module = GRUModule(
1850+
input_size=3,
1851+
hidden_size=12,
1852+
batch_first=True,
1853+
in_keys="abc",
1854+
out_keys=["intermediate", ("next", "hidden")],
1855+
)
1856+
with pytest.raises(ValueError, match="in_keys"):
1857+
gru_module = GRUModule(
1858+
input_size=3,
1859+
hidden_size=12,
1860+
batch_first=True,
1861+
in_key="smth",
1862+
in_keys=["observation", "hidden0", "hidden1"],
1863+
out_keys=["intermediate", ("next", "hidden")],
1864+
)
1865+
with pytest.raises(ValueError, match="out_keys"):
1866+
gru_module = GRUModule(
1867+
input_size=3,
1868+
hidden_size=12,
1869+
batch_first=True,
1870+
in_keys=["observation", "hidden"],
1871+
out_keys=["intermediate", ("next", "hidden"), "other"],
1872+
)
1873+
with pytest.raises(TypeError, match="incompatible function arguments"):
1874+
gru_module = GRUModule(
1875+
input_size=3,
1876+
hidden_size=12,
1877+
batch_first=True,
1878+
in_keys=["observation", "hidden"],
1879+
out_keys="abc",
1880+
)
1881+
with pytest.raises(ValueError, match="out_keys"):
1882+
gru_module = GRUModule(
1883+
input_size=3,
1884+
hidden_size=12,
1885+
batch_first=True,
1886+
in_keys=["observation", "hidden"],
1887+
out_key="smth",
1888+
out_keys=["intermediate", ("next", "hidden"), "other"],
1889+
)
1890+
gru_module = GRUModule(
1891+
input_size=3,
1892+
hidden_size=12,
1893+
batch_first=True,
1894+
in_keys=["observation", "hidden"],
1895+
out_keys=["intermediate", ("next", "hidden")],
1896+
)
1897+
td = TensorDict({"observation": torch.randn(3)}, [])
1898+
with pytest.raises(KeyError, match="is_init"):
1899+
gru_module(td)
1900+
1901+
def test_set_temporal_mode(self):
1902+
gru_module = GRUModule(
1903+
input_size=3,
1904+
hidden_size=12,
1905+
batch_first=True,
1906+
in_keys=["observation", "hidden"],
1907+
out_keys=["intermediate", ("next", "hidden")],
1908+
)
1909+
assert gru_module.set_recurrent_mode(False) is gru_module
1910+
assert not gru_module.set_recurrent_mode(False).recurrent_mode
1911+
assert gru_module.set_recurrent_mode(True) is not gru_module
1912+
assert gru_module.set_recurrent_mode(True).recurrent_mode
1913+
assert set(gru_module.set_recurrent_mode(True).parameters()) == set(
1914+
gru_module.parameters()
1915+
)
1916+
1917+
def test_noncontiguous(self):
1918+
gru_module = GRUModule(
1919+
input_size=3,
1920+
hidden_size=12,
1921+
batch_first=True,
1922+
in_keys=["bork", "h"],
1923+
out_keys=["dork", ("next", "h")],
1924+
)
1925+
td = TensorDict(
1926+
{
1927+
"bork": torch.randn(3, 3),
1928+
"is_init": torch.zeros(3, 1, dtype=torch.bool),
1929+
},
1930+
[3],
1931+
)
1932+
padded = pad(td, [0, 5])
1933+
gru_module(padded)
1934+
1935+
@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
1936+
def test_single_step(self, shape):
1937+
td = TensorDict(
1938+
{
1939+
"observation": torch.zeros(*shape, 3),
1940+
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
1941+
},
1942+
shape,
1943+
)
1944+
gru_module = GRUModule(
1945+
input_size=3,
1946+
hidden_size=12,
1947+
batch_first=True,
1948+
in_keys=["observation", "hidden"],
1949+
out_keys=["intermediate", ("next", "hidden")],
1950+
)
1951+
td = gru_module(td)
1952+
td_next = step_mdp(td, keep_other=True)
1953+
td_next = gru_module(td_next)
1954+
1955+
assert not torch.isclose(td_next["next", "hidden"], td["next", "hidden"]).any()
1956+
1957+
@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
1958+
@pytest.mark.parametrize("t", [1, 10])
1959+
def test_single_step_vs_multi(self, shape, t):
1960+
td = TensorDict(
1961+
{
1962+
"observation": torch.arange(t, dtype=torch.float32)
1963+
.unsqueeze(-1)
1964+
.expand(*shape, t, 3),
1965+
"is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
1966+
},
1967+
[*shape, t],
1968+
)
1969+
gru_module_ss = GRUModule(
1970+
input_size=3,
1971+
hidden_size=12,
1972+
batch_first=True,
1973+
in_keys=["observation", "hidden"],
1974+
out_keys=["intermediate", ("next", "hidden")],
1975+
)
1976+
gru_module_ms = gru_module_ss.set_recurrent_mode()
1977+
gru_module_ms(td)
1978+
td_ss = TensorDict(
1979+
{
1980+
"observation": torch.zeros(*shape, 3),
1981+
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
1982+
},
1983+
shape,
1984+
)
1985+
for _t in range(t):
1986+
gru_module_ss(td_ss)
1987+
td_ss = step_mdp(td_ss, keep_other=True)
1988+
td_ss["observation"][:] = _t + 1
1989+
torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :])
1990+
1991+
@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
1992+
def test_multi_consecutive(self, shape):
1993+
t = 20
1994+
td = TensorDict(
1995+
{
1996+
"observation": torch.arange(t, dtype=torch.float32)
1997+
.unsqueeze(-1)
1998+
.expand(*shape, t, 3),
1999+
"is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
2000+
},
2001+
[*shape, t],
2002+
)
2003+
if shape:
2004+
td["is_init"][0, ..., 13, :] = True
2005+
else:
2006+
td["is_init"][13, :] = True
2007+
2008+
gru_module_ss = GRUModule(
2009+
input_size=3,
2010+
hidden_size=12,
2011+
batch_first=True,
2012+
in_keys=["observation", "hidden"],
2013+
out_keys=["intermediate", ("next", "hidden")],
2014+
)
2015+
gru_module_ms = gru_module_ss.set_recurrent_mode()
2016+
gru_module_ms(td)
2017+
td_ss = TensorDict(
2018+
{
2019+
"observation": torch.zeros(*shape, 3),
2020+
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
2021+
},
2022+
shape,
2023+
)
2024+
for _t in range(t):
2025+
td_ss["is_init"][:] = td["is_init"][..., _t, :]
2026+
gru_module_ss(td_ss)
2027+
td_ss = step_mdp(td_ss, keep_other=True)
2028+
td_ss["observation"][:] = _t + 1
2029+
torch.testing.assert_close(
2030+
td_ss["intermediate"], td["intermediate"][..., -1, :]
2031+
)
2032+
2033+
def test_gru_parallel_env(self):
2034+
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv
2035+
2036+
# tests that hidden states are carried over with parallel envs
2037+
gru_module = GRUModule(
2038+
input_size=7,
2039+
hidden_size=12,
2040+
num_layers=2,
2041+
in_key="observation",
2042+
out_key="features",
2043+
)
2044+
2045+
def create_transformed_env():
2046+
primer = gru_module.make_tensordict_primer()
2047+
env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
2048+
env = TransformedEnv(env)
2049+
env.append_transform(InitTracker())
2050+
env.append_transform(primer)
2051+
return env
2052+
2053+
env = ParallelEnv(
2054+
create_env_fn=create_transformed_env,
2055+
num_workers=2,
2056+
)
2057+
2058+
mlp = TensorDictModule(
2059+
MLP(
2060+
in_features=12,
2061+
out_features=7,
2062+
num_cells=[],
2063+
),
2064+
in_keys=["features"],
2065+
out_keys=["logits"],
2066+
)
2067+
2068+
actor_model = TensorDictSequential(gru_module, mlp)
2069+
2070+
actor = ProbabilisticActor(
2071+
module=actor_model,
2072+
in_keys=["logits"],
2073+
out_keys=["action"],
2074+
distribution_class=torch.distributions.Categorical,
2075+
return_log_prob=True,
2076+
)
2077+
for break_when_any_done in [False, True]:
2078+
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
2079+
assert (data.get("recurrent_state") != 0.0).any()
2080+
assert (data.get(("next", "recurrent_state")) != 0.0).all()
2081+
2082+
18252083
def test_safe_specs():
18262084

18272085
out_key = ("a", "b")

torchrl/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
DistributionalQValueModule,
5757
EGreedyModule,
5858
EGreedyWrapper,
59+
GRUModule,
5960
LMHeadActorValueOperator,
6061
LSTMModule,
6162
OrnsteinUhlenbeckProcessWrapper,

torchrl/modules/tensordict_module/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@
3131
SafeProbabilisticModule,
3232
SafeProbabilisticTensorDictSequential,
3333
)
34-
from .rnn import LSTMModule
34+
from .rnn import GRUModule, LSTMModule
3535
from .sequence import SafeSequential
3636
from .world_models import WorldModelWrapper

0 commit comments

Comments
 (0)