|
26 | 26 | AdditiveGaussianWrapper,
|
27 | 27 | DecisionTransformerInferenceWrapper,
|
28 | 28 | DTActor,
|
| 29 | + GRUModule, |
29 | 30 | LSTMModule,
|
30 | 31 | MLP,
|
31 | 32 | NormalParamWrapper,
|
@@ -1645,9 +1646,9 @@ def test_set_temporal_mode(self):
|
1645 | 1646 | out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
|
1646 | 1647 | )
|
1647 | 1648 | assert lstm_module.set_recurrent_mode(False) is lstm_module
|
1648 |
| - assert not lstm_module.set_recurrent_mode(False).temporal_mode |
| 1649 | + assert not lstm_module.set_recurrent_mode(False).recurrent_mode |
1649 | 1650 | assert lstm_module.set_recurrent_mode(True) is not lstm_module
|
1650 |
| - assert lstm_module.set_recurrent_mode(True).temporal_mode |
| 1651 | + assert lstm_module.set_recurrent_mode(True).recurrent_mode |
1651 | 1652 | assert set(lstm_module.set_recurrent_mode(True).parameters()) == set(
|
1652 | 1653 | lstm_module.parameters()
|
1653 | 1654 | )
|
@@ -1822,6 +1823,263 @@ def create_transformed_env():
|
1822 | 1823 | assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
|
1823 | 1824 |
|
1824 | 1825 |
|
class TestGRUModule:
    """Unit tests for ``GRUModule``.

    Mirrors the ``LSTMModule`` test-suite above, adapted to the single
    hidden state of a GRU: constructor validation, recurrent-mode toggling,
    single-step vs. multi-step equivalence, and hidden-state carry-over
    inside a ``ParallelEnv`` rollout.
    """

    def test_errs(self):
        # batch_first=False is rejected for tensordict-based I/O.
        with pytest.raises(ValueError, match="batch_first"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=False,
                in_keys=["observation", "hidden"],
                out_keys=["intermediate", ("next", "hidden")],
            )
        # A GRU carries one hidden state: exactly two in_keys are expected,
        # so three in_keys must raise.
        with pytest.raises(ValueError, match="in_keys"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=[
                    "observation",
                    "hidden0",
                    "hidden1",
                ],
                out_keys=["intermediate", ("next", "hidden")],
            )
        # in_keys must be a sequence of keys, not a bare string.
        with pytest.raises(TypeError, match="incompatible function arguments"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys="abc",
                out_keys=["intermediate", ("next", "hidden")],
            )
        # in_key and in_keys are mutually exclusive.
        with pytest.raises(ValueError, match="in_keys"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_key="smth",
                in_keys=["observation", "hidden0", "hidden1"],
                out_keys=["intermediate", ("next", "hidden")],
            )
        # Exactly two out_keys are expected (intermediate + next hidden).
        with pytest.raises(ValueError, match="out_keys"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=["observation", "hidden"],
                out_keys=["intermediate", ("next", "hidden"), "other"],
            )
        # out_keys must be a sequence of keys, not a bare string.
        with pytest.raises(TypeError, match="incompatible function arguments"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=["observation", "hidden"],
                out_keys="abc",
            )
        # out_key and out_keys are mutually exclusive.
        with pytest.raises(ValueError, match="out_keys"):
            GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=["observation", "hidden"],
                out_key="smth",
                out_keys=["intermediate", ("next", "hidden"), "other"],
            )
        # A well-formed module still requires an "is_init" entry at call time.
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        td = TensorDict({"observation": torch.randn(3)}, [])
        with pytest.raises(KeyError, match="is_init"):
            gru_module(td)

    def test_set_temporal_mode(self):
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        # Disabling recurrent mode on a non-recurrent module is a no-op and
        # returns the very same object.
        assert gru_module.set_recurrent_mode(False) is gru_module
        assert not gru_module.set_recurrent_mode(False).recurrent_mode
        # Enabling recurrent mode returns a distinct module that shares the
        # underlying parameters with the original.
        assert gru_module.set_recurrent_mode(True) is not gru_module
        assert gru_module.set_recurrent_mode(True).recurrent_mode
        assert set(gru_module.set_recurrent_mode(True).parameters()) == set(
            gru_module.parameters()
        )

    def test_noncontiguous(self):
        # Padding a tensordict yields non-contiguous storage; the module must
        # accept it without error.
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["bork", "h"],
            out_keys=["dork", ("next", "h")],
        )
        td = TensorDict(
            {
                "bork": torch.randn(3, 3),
                "is_init": torch.zeros(3, 1, dtype=torch.bool),
            },
            [3],
        )
        padded = pad(td, [0, 5])
        gru_module(padded)

    @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
    def test_single_step(self, shape):
        """Two consecutive single steps must yield different hidden states."""
        td = TensorDict(
            {
                "observation": torch.zeros(*shape, 3),
                "is_init": torch.zeros(*shape, 1, dtype=torch.bool),
            },
            shape,
        )
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        td = gru_module(td)
        # Advance the MDP so the produced hidden state is fed back in.
        td_next = step_mdp(td, keep_other=True)
        td_next = gru_module(td_next)

        assert not torch.isclose(td_next["next", "hidden"], td["next", "hidden"]).any()

    @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
    @pytest.mark.parametrize("t", [1, 10])
    def test_single_step_vs_multi(self, shape, t):
        """Unrolling step-by-step must match one recurrent multi-step pass."""
        td = TensorDict(
            {
                "observation": torch.arange(t, dtype=torch.float32)
                .unsqueeze(-1)
                .expand(*shape, t, 3),
                "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
            },
            [*shape, t],
        )
        gru_module_ss = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        # Recurrent (multi-step) copy sharing the same parameters.
        gru_module_ms = gru_module_ss.set_recurrent_mode()
        gru_module_ms(td)
        td_ss = TensorDict(
            {
                "observation": torch.zeros(*shape, 3),
                "is_init": torch.zeros(*shape, 1, dtype=torch.bool),
            },
            shape,
        )
        # Manual unroll: one env-style step per time index, feeding the
        # hidden state forward through step_mdp.
        for _t in range(t):
            gru_module_ss(td_ss)
            td_ss = step_mdp(td_ss, keep_other=True)
            td_ss["observation"][:] = _t + 1
        torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :])

    @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
    def test_multi_consecutive(self, shape):
        """A mid-trajectory reset ("is_init") must be honored identically in
        single-step and multi-step modes."""
        t = 20
        td = TensorDict(
            {
                "observation": torch.arange(t, dtype=torch.float32)
                .unsqueeze(-1)
                .expand(*shape, t, 3),
                "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
            },
            [*shape, t],
        )
        # Flag a reset at time index 13 (on the first batch element when
        # there is a batch dimension).
        if shape:
            td["is_init"][0, ..., 13, :] = True
        else:
            td["is_init"][13, :] = True

        gru_module_ss = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        gru_module_ms = gru_module_ss.set_recurrent_mode()
        gru_module_ms(td)
        td_ss = TensorDict(
            {
                "observation": torch.zeros(*shape, 3),
                "is_init": torch.zeros(*shape, 1, dtype=torch.bool),
            },
            shape,
        )
        for _t in range(t):
            td_ss["is_init"][:] = td["is_init"][..., _t, :]
            gru_module_ss(td_ss)
            td_ss = step_mdp(td_ss, keep_other=True)
            td_ss["observation"][:] = _t + 1
        torch.testing.assert_close(
            td_ss["intermediate"], td["intermediate"][..., -1, :]
        )

    def test_gru_parallel_env(self):
        from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

        # tests that hidden states are carried over with parallel envs
        gru_module = GRUModule(
            input_size=7,
            hidden_size=12,
            num_layers=2,
            in_key="observation",
            out_key="features",
        )

        def create_transformed_env():
            # The primer registers the recurrent-state entries in the env spec
            # so workers allocate and carry them between steps.
            primer = gru_module.make_tensordict_primer()
            env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
            env = TransformedEnv(env)
            env.append_transform(InitTracker())
            env.append_transform(primer)
            return env

        env = ParallelEnv(
            create_env_fn=create_transformed_env,
            num_workers=2,
        )

        mlp = TensorDictModule(
            MLP(
                in_features=12,
                out_features=7,
                num_cells=[],
            ),
            in_keys=["features"],
            out_keys=["logits"],
        )

        actor_model = TensorDictSequential(gru_module, mlp)

        actor = ProbabilisticActor(
            module=actor_model,
            in_keys=["logits"],
            out_keys=["action"],
            distribution_class=torch.distributions.Categorical,
            return_log_prob=True,
        )
        for break_when_any_done in [False, True]:
            data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
            # At least one state must have been written; every "next" state
            # must be non-zero, proving carry-over across steps/workers.
            assert (data.get("recurrent_state") != 0.0).any()
            assert (data.get(("next", "recurrent_state")) != 0.0).all()
1825 | 2083 | def test_safe_specs():
|
1826 | 2084 |
|
1827 | 2085 | out_key = ("a", "b")
|
|
0 commit comments