Skip to content

Commit 278e9be

Browse files
authored
[Test] Check dtypes of envs (#666)
1 parent ecedcf1 commit 278e9be

File tree

5 files changed

+311
-249
lines changed

5 files changed

+311
-249
lines changed

test/_utils_internal.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import pytest
1313
import torch.cuda
1414
from torchrl._utils import seed_generator
15+
from torchrl.data import CompositeSpec
16+
from torchrl.data.tensordict.tensordict import TensorDictBase
1517
from torchrl.envs import EnvBase
1618

1719

@@ -54,6 +56,34 @@ def _test_fake_tensordict(env: EnvBase):
5456
for key in keys2:
5557
assert fake_tensordict[key].shape == real_tensordict[key].shape
5658

59+
# test dtypes
60+
for key, value in real_tensordict.unflatten_keys(".").items():
61+
_check_dtype(key, value, env.observation_spec, env.input_spec)
62+
63+
64+
def _check_dtype(key, value, obs_spec, input_spec):
65+
if key.startswith("next_"):
66+
return
67+
if isinstance(value, TensorDictBase):
68+
for _key, _value in value.items():
69+
if isinstance(obs_spec, CompositeSpec) and "next_" + key in obs_spec.keys():
70+
_check_dtype(_key, _value, obs_spec["next_" + key], input_spec=None)
71+
elif isinstance(input_spec, CompositeSpec) and key in input_spec.keys():
72+
_check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
73+
else:
74+
raise KeyError(f"key '{_key}' is unknown.")
75+
else:
76+
if obs_spec is not None and "next_" + key in obs_spec.keys():
77+
assert (
78+
obs_spec["next_" + key].dtype is value.dtype
79+
), f"{obs_spec['next_' + key].dtype} vs {value.dtype} for {key}"
80+
elif input_spec is not None and key in input_spec.keys():
81+
assert (
82+
input_spec[key].dtype is value.dtype
83+
), f"{input_spec[key].dtype} vs {value.dtype} for {key}"
84+
else:
85+
assert key in {"done", "reward"}, (key, obs_spec, input_spec)
86+
5787

5888
# Decorator to retry upon certain Exceptions.
5989
def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):

test/test_env.py

Lines changed: 0 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,8 @@
2222
MockSerialEnv,
2323
)
2424
from packaging import version
25-
from scipy.stats import chisquare
2625
from torch import nn
2726
from torchrl.data.tensor_specs import (
28-
BoundedTensorSpec,
29-
DiscreteTensorSpec,
30-
MultOneHotDiscreteTensorSpec,
31-
NdBoundedTensorSpec,
3227
OneHotDiscreteTensorSpec,
3328
UnboundedContinuousTensorSpec,
3429
)
@@ -917,132 +912,6 @@ def env_fn2(seed):
917912
env2.close()
918913

919914

920-
class TestSpec:
921-
@pytest.mark.parametrize(
922-
"action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec]
923-
)
924-
def test_discrete_action_spec_reconstruct(self, action_spec_cls):
925-
torch.manual_seed(0)
926-
action_spec = action_spec_cls(10)
927-
928-
actions_tensors = [action_spec.rand() for _ in range(10)]
929-
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
930-
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
931-
assert all(
932-
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
933-
)
934-
935-
actions_numpy = [int(np.random.randint(0, 10, (1,))) for a in actions_tensors]
936-
actions_tensors = [action_spec.encode(a) for a in actions_numpy]
937-
actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
938-
assert all([(a1 == a2) for a1, a2 in zip(actions_numpy, actions_numpy_2)])
939-
940-
def test_mult_discrete_action_spec_reconstruct(self):
941-
torch.manual_seed(0)
942-
action_spec = MultOneHotDiscreteTensorSpec((10, 5))
943-
944-
actions_tensors = [action_spec.rand() for _ in range(10)]
945-
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
946-
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
947-
assert all(
948-
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
949-
)
950-
951-
actions_numpy = [
952-
np.concatenate(
953-
[np.random.randint(0, 10, (1,)), np.random.randint(0, 5, (1,))], 0
954-
)
955-
for a in actions_tensors
956-
]
957-
actions_tensors = [action_spec.encode(a) for a in actions_numpy]
958-
actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
959-
assert all([(a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2)])
960-
961-
def test_one_hot_discrete_action_spec_rand(self):
962-
torch.manual_seed(0)
963-
action_spec = OneHotDiscreteTensorSpec(10)
964-
965-
sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)
966-
967-
sample_list = sample.argmax(-1)
968-
sample_list = list([sum(sample_list == i).item() for i in range(10)])
969-
assert chisquare(sample_list).pvalue > 0.1
970-
971-
sample = action_spec.to_numpy(sample)
972-
sample = [sum(sample == i) for i in range(10)]
973-
assert chisquare(sample).pvalue > 0.1
974-
975-
def test_categorical_action_spec_rand(self):
976-
torch.manual_seed(0)
977-
action_spec = DiscreteTensorSpec(10)
978-
979-
sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)
980-
981-
sample_list = sample[:, 0]
982-
sample_list = list([sum(sample_list == i).item() for i in range(10)])
983-
assert chisquare(sample_list).pvalue > 0.1
984-
985-
sample = action_spec.to_numpy(sample)
986-
sample = [sum(sample == i) for i in range(10)]
987-
assert chisquare(sample).pvalue > 0.1
988-
989-
def test_mult_discrete_action_spec_rand(self):
990-
torch.manual_seed(0)
991-
ns = (10, 5)
992-
N = 100000
993-
action_spec = MultOneHotDiscreteTensorSpec((10, 5))
994-
995-
actions_tensors = [action_spec.rand() for _ in range(10)]
996-
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
997-
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
998-
assert all(
999-
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
1000-
)
1001-
1002-
sample = np.stack(
1003-
[action_spec.to_numpy(action_spec.rand()) for _ in range(N)], 0
1004-
)
1005-
assert sample.shape[0] == N
1006-
assert sample.shape[1] == 2
1007-
assert sample.ndim == 2, f"found shape: {sample.shape}"
1008-
1009-
sample0 = sample[:, 0]
1010-
sample_list = list([sum(sample0 == i) for i in range(ns[0])])
1011-
assert chisquare(sample_list).pvalue > 0.1
1012-
1013-
sample1 = sample[:, 1]
1014-
sample_list = list([sum(sample1 == i) for i in range(ns[1])])
1015-
assert chisquare(sample_list).pvalue > 0.1
1016-
1017-
def test_categorical_action_spec_encode(self):
1018-
action_spec = DiscreteTensorSpec(10)
1019-
1020-
projected = action_spec.project(
1021-
torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long)
1022-
)
1023-
assert (
1024-
projected == torch.tensor([0, 0, 0, 1, 9, 9, 9], dtype=torch.long)
1025-
).all()
1026-
1027-
projected = action_spec.project(
1028-
torch.tensor([-100.0, -1.0, 0.0, 1.0, 9.0, 10.0, 100.0], dtype=torch.float)
1029-
)
1030-
assert (
1031-
projected == torch.tensor([0, 0, 0, 1, 9, 9, 9], dtype=torch.long)
1032-
).all()
1033-
1034-
def test_bounded_rand(self):
1035-
spec = BoundedTensorSpec(-3, 3)
1036-
sample = torch.stack([spec.rand() for _ in range(100)])
1037-
assert (-3 <= sample).all() and (3 >= sample).all()
1038-
1039-
def test_ndbounded_shape(self):
1040-
spec = NdBoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5])
1041-
sample = torch.stack([spec.rand() for _ in range(100)], 0)
1042-
assert (-3 <= sample).all() and (3 >= sample).all()
1043-
assert sample.shape == torch.Size([100, 10, 5])
1044-
1045-
1046915
@pytest.mark.skipif(not _has_gym, reason="no gym")
1047916
def test_seed():
1048917
torch.manual_seed(0)

0 commit comments

Comments
 (0)