|
22 | 22 | MockSerialEnv,
|
23 | 23 | )
|
24 | 24 | from packaging import version
|
25 |
| -from scipy.stats import chisquare |
26 | 25 | from torch import nn
|
27 | 26 | from torchrl.data.tensor_specs import (
|
28 |
| - BoundedTensorSpec, |
29 |
| - DiscreteTensorSpec, |
30 |
| - MultOneHotDiscreteTensorSpec, |
31 |
| - NdBoundedTensorSpec, |
32 | 27 | OneHotDiscreteTensorSpec,
|
33 | 28 | UnboundedContinuousTensorSpec,
|
34 | 29 | )
|
@@ -917,132 +912,6 @@ def env_fn2(seed):
|
917 | 912 | env2.close()
|
918 | 913 |
|
919 | 914 |
|
920 |
| -class TestSpec: |
921 |
| - @pytest.mark.parametrize( |
922 |
| - "action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec] |
923 |
| - ) |
924 |
| - def test_discrete_action_spec_reconstruct(self, action_spec_cls): |
925 |
| - torch.manual_seed(0) |
926 |
| - action_spec = action_spec_cls(10) |
927 |
| - |
928 |
| - actions_tensors = [action_spec.rand() for _ in range(10)] |
929 |
| - actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] |
930 |
| - actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy] |
931 |
| - assert all( |
932 |
| - [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)] |
933 |
| - ) |
934 |
| - |
935 |
| - actions_numpy = [int(np.random.randint(0, 10, (1,))) for a in actions_tensors] |
936 |
| - actions_tensors = [action_spec.encode(a) for a in actions_numpy] |
937 |
| - actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors] |
938 |
| - assert all([(a1 == a2) for a1, a2 in zip(actions_numpy, actions_numpy_2)]) |
939 |
| - |
940 |
| - def test_mult_discrete_action_spec_reconstruct(self): |
941 |
| - torch.manual_seed(0) |
942 |
| - action_spec = MultOneHotDiscreteTensorSpec((10, 5)) |
943 |
| - |
944 |
| - actions_tensors = [action_spec.rand() for _ in range(10)] |
945 |
| - actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] |
946 |
| - actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy] |
947 |
| - assert all( |
948 |
| - [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)] |
949 |
| - ) |
950 |
| - |
951 |
| - actions_numpy = [ |
952 |
| - np.concatenate( |
953 |
| - [np.random.randint(0, 10, (1,)), np.random.randint(0, 5, (1,))], 0 |
954 |
| - ) |
955 |
| - for a in actions_tensors |
956 |
| - ] |
957 |
| - actions_tensors = [action_spec.encode(a) for a in actions_numpy] |
958 |
| - actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors] |
959 |
| - assert all([(a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2)]) |
960 |
| - |
961 |
| - def test_one_hot_discrete_action_spec_rand(self): |
962 |
| - torch.manual_seed(0) |
963 |
| - action_spec = OneHotDiscreteTensorSpec(10) |
964 |
| - |
965 |
| - sample = torch.stack([action_spec.rand() for _ in range(10000)], 0) |
966 |
| - |
967 |
| - sample_list = sample.argmax(-1) |
968 |
| - sample_list = list([sum(sample_list == i).item() for i in range(10)]) |
969 |
| - assert chisquare(sample_list).pvalue > 0.1 |
970 |
| - |
971 |
| - sample = action_spec.to_numpy(sample) |
972 |
| - sample = [sum(sample == i) for i in range(10)] |
973 |
| - assert chisquare(sample).pvalue > 0.1 |
974 |
| - |
975 |
| - def test_categorical_action_spec_rand(self): |
976 |
| - torch.manual_seed(0) |
977 |
| - action_spec = DiscreteTensorSpec(10) |
978 |
| - |
979 |
| - sample = torch.stack([action_spec.rand() for _ in range(10000)], 0) |
980 |
| - |
981 |
| - sample_list = sample[:, 0] |
982 |
| - sample_list = list([sum(sample_list == i).item() for i in range(10)]) |
983 |
| - assert chisquare(sample_list).pvalue > 0.1 |
984 |
| - |
985 |
| - sample = action_spec.to_numpy(sample) |
986 |
| - sample = [sum(sample == i) for i in range(10)] |
987 |
| - assert chisquare(sample).pvalue > 0.1 |
988 |
| - |
989 |
| - def test_mult_discrete_action_spec_rand(self): |
990 |
| - torch.manual_seed(0) |
991 |
| - ns = (10, 5) |
992 |
| - N = 100000 |
993 |
| - action_spec = MultOneHotDiscreteTensorSpec((10, 5)) |
994 |
| - |
995 |
| - actions_tensors = [action_spec.rand() for _ in range(10)] |
996 |
| - actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] |
997 |
| - actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy] |
998 |
| - assert all( |
999 |
| - [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)] |
1000 |
| - ) |
1001 |
| - |
1002 |
| - sample = np.stack( |
1003 |
| - [action_spec.to_numpy(action_spec.rand()) for _ in range(N)], 0 |
1004 |
| - ) |
1005 |
| - assert sample.shape[0] == N |
1006 |
| - assert sample.shape[1] == 2 |
1007 |
| - assert sample.ndim == 2, f"found shape: {sample.shape}" |
1008 |
| - |
1009 |
| - sample0 = sample[:, 0] |
1010 |
| - sample_list = list([sum(sample0 == i) for i in range(ns[0])]) |
1011 |
| - assert chisquare(sample_list).pvalue > 0.1 |
1012 |
| - |
1013 |
| - sample1 = sample[:, 1] |
1014 |
| - sample_list = list([sum(sample1 == i) for i in range(ns[1])]) |
1015 |
| - assert chisquare(sample_list).pvalue > 0.1 |
1016 |
| - |
1017 |
| - def test_categorical_action_spec_encode(self): |
1018 |
| - action_spec = DiscreteTensorSpec(10) |
1019 |
| - |
1020 |
| - projected = action_spec.project( |
1021 |
| - torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long) |
1022 |
| - ) |
1023 |
| - assert ( |
1024 |
| - projected == torch.tensor([0, 0, 0, 1, 9, 9, 9], dtype=torch.long) |
1025 |
| - ).all() |
1026 |
| - |
1027 |
| - projected = action_spec.project( |
1028 |
| - torch.tensor([-100.0, -1.0, 0.0, 1.0, 9.0, 10.0, 100.0], dtype=torch.float) |
1029 |
| - ) |
1030 |
| - assert ( |
1031 |
| - projected == torch.tensor([0, 0, 0, 1, 9, 9, 9], dtype=torch.long) |
1032 |
| - ).all() |
1033 |
| - |
1034 |
| - def test_bounded_rand(self): |
1035 |
| - spec = BoundedTensorSpec(-3, 3) |
1036 |
| - sample = torch.stack([spec.rand() for _ in range(100)]) |
1037 |
| - assert (-3 <= sample).all() and (3 >= sample).all() |
1038 |
| - |
1039 |
| - def test_ndbounded_shape(self): |
1040 |
| - spec = NdBoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5]) |
1041 |
| - sample = torch.stack([spec.rand() for _ in range(100)], 0) |
1042 |
| - assert (-3 <= sample).all() and (3 >= sample).all() |
1043 |
| - assert sample.shape == torch.Size([100, 10, 5]) |
1044 |
| - |
1045 |
| - |
1046 | 915 | @pytest.mark.skipif(not _has_gym, reason="no gym")
|
1047 | 916 | def test_seed():
|
1048 | 917 | torch.manual_seed(0)
|
|
0 commit comments