Skip to content

Commit 9555dd9

Browse files
committed
Fix tests and warnings when running locally with a GPU (#2069)
* Fix test when GPU is available * Sort file list for consistent results * Ignore A2C warnings too
1 parent 593b2d9 commit 9555dd9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/test_save_load.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -758,16 +758,16 @@ def test_no_resource_warning(tmp_path):
758758

759759
# check that files are properly closed
760760
# Create a PPO agent and save it
761-
PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
762-
PPO.load(tmp_path / "dqn_cartpole")
761+
PPO("MlpPolicy", "CartPole-v1", device="cpu").save(tmp_path / "dqn_cartpole")
762+
PPO.load(tmp_path / "dqn_cartpole", device="cpu")
763763

764-
PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole"))
765-
PPO.load(str(tmp_path / "dqn_cartpole"))
764+
PPO("MlpPolicy", "CartPole-v1", device="cpu").save(str(tmp_path / "dqn_cartpole"))
765+
PPO.load(str(tmp_path / "dqn_cartpole"), device="cpu")
766766

767767
# Do the same but in memory, should not close the file
768768
with tempfile.TemporaryFile() as fp:
769-
PPO("MlpPolicy", "CartPole-v1").save(fp)
770-
PPO.load(fp)
769+
PPO("MlpPolicy", "CartPole-v1", device="cpu").save(fp)
770+
PPO.load(fp, device="cpu")
771771
assert not fp.closed
772772

773773
# Same but with replay buffer

0 commit comments

Comments
 (0)