Skip to content

Commit 8503378

Browse files
author
Vincent Moens
authored
[BugFix] Fix run_type_checks (#1570)
1 parent c00b62a commit 8503378

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

test/test_env.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,29 @@ def test_num_threads():
21172117
torch.set_num_threads(num_threads)
21182118

21192119

2120+
def test_run_type_checks():
2121+
env = ContinuousActionVecMockEnv()
2122+
env._run_type_checks = False
2123+
check_env_specs(env)
2124+
env._run_type_checks = True
2125+
check_env_specs(env)
2126+
env.output_spec.unlock_()
2127+
# check type check on done
2128+
env.output_spec["full_done_spec", "done"].dtype = torch.int
2129+
with pytest.raises(TypeError, match="expected done.dtype to"):
2130+
check_env_specs(env)
2131+
env.output_spec["full_done_spec", "done"].dtype = torch.bool
2132+
# check type check on reward
2133+
env.output_spec["full_reward_spec", "reward"].dtype = torch.int
2134+
with pytest.raises(TypeError, match="expected"):
2135+
check_env_specs(env)
2136+
env.output_spec["full_reward_spec", "reward"].dtype = torch.float
2137+
# check type check on obs
2138+
env.output_spec["full_observation_spec", "observation"].dtype = torch.float16
2139+
with pytest.raises(TypeError):
2140+
check_env_specs(env)
2141+
2142+
21202143
if __name__ == "__main__":
21212144
args, unknown = argparse.ArgumentParser().parse_known_args()
21222145
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/common.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,10 +1193,9 @@ def _step_proc_data(self, next_tensordict_out):
11931193
next_tensordict_out.set(done_key, done)
11941194

11951195
if self.run_type_checks:
1196-
# TODO: check these errors
1197-
for key in self._select_observation_keys(next_tensordict_out):
1196+
for key, spec in self.observation_spec.items():
11981197
obs = next_tensordict_out.get(key)
1199-
self.observation_spec.type_check(obs, key)
1198+
spec.type_check(obs)
12001199

12011200
for reward_key in self.reward_keys:
12021201
if (
@@ -1213,10 +1212,10 @@ def _step_proc_data(self, next_tensordict_out):
12131212
for done_key in self.done_keys:
12141213
if (
12151214
next_tensordict_out.get(done_key).dtype
1216-
is not self.output_spec["full_done_spec"].get(done_key).dtype
1215+
is not self.output_spec["full_done_spec", done_key].dtype
12171216
):
12181217
raise TypeError(
1219-
f"expected done.dtype to be torch.bool but got {next_tensordict_out.get(done_key).dtype}"
1218+
f"expected done.dtype to be {self.output_spec['full_done_spec', done_key].dtype} but got {next_tensordict_out.get(done_key).dtype}"
12201219
)
12211220
return next_tensordict_out
12221221

0 commit comments

Comments
 (0)