Skip to content

Fix dqn model evals #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,16 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest atari jax"
- name: Run core tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_atari_jax.py
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
- name: Run gymnasium tests
run: poetry run pytest tests/test_atari_gymnasium.py
- name: Run core tests with jax
- name: Run gymnasium tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_atari_jax.py
run: poetry run pytest tests/test_atari_jax_gymnasium.py

test-pybullet-envs:
strategy:
Expand Down Expand Up @@ -348,7 +351,7 @@ jobs:

# envpool tests
- name: Install envpool dependencies
run: poetry install -E "pytest envpool jax"
run: poetry install -E "pytest envpool jax ppo_atari_envpool_xla_jax_scan"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run envpool tests
Expand Down
12 changes: 7 additions & 5 deletions cleanrl_utils/evals/dqn_eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
from typing import Callable

import gym
import gymnasium as gym
import numpy as np
import torch

Expand All @@ -22,17 +22,19 @@ def evaluate(
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

obs = envs.reset()
obs, _ = envs.reset()
episodic_returns = []
while len(episodic_returns) < eval_episodes:
if random.random() < epsilon:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
q_values = model(torch.Tensor(obs).to(device))
actions = torch.argmax(q_values, dim=1).cpu().numpy()
next_obs, _, _, infos = envs.step(actions)
for info in infos:
if "episode" in info.keys():
next_obs, _, _, _, infos = envs.step(actions)
if "final_info" in infos:
for info in infos["final_info"]:
if "episode" not in info:
continue
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs
Expand Down
12 changes: 7 additions & 5 deletions cleanrl_utils/evals/dqn_jax_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import flax
import flax.linen as nn
import gym
import gymnasium as gym
import jax
import numpy as np

Expand All @@ -20,7 +20,7 @@ def evaluate(
seed=1,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
obs = envs.reset()
obs, _ = envs.reset()
model = Model(action_dim=envs.single_action_space.n)
q_key = jax.random.PRNGKey(seed)
params = model.init(q_key, obs)
Expand All @@ -36,9 +36,11 @@ def evaluate(
q_values = model.apply(params, obs)
actions = q_values.argmax(axis=-1)
actions = jax.device_get(actions)
next_obs, _, _, infos = envs.step(actions)
for info in infos:
if "episode" in info.keys():
next_obs, _, _, _, infos = envs.step(actions)
if "final_info" in infos:
for info in infos["final_info"]:
if "episode" not in info:
continue
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs
Expand Down
8 changes: 8 additions & 0 deletions tests/test_atari_gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@ def test_dqn():
shell=True,
check=True,
)


def test_dqn_eval():
    """Smoke-test ``cleanrl/dqn_atari.py`` with ``--save-model True``.

    The script is launched through the shell with small step, buffer and
    batch values.  ``check=True`` makes ``subprocess.run`` raise
    ``subprocess.CalledProcessError`` when the script exits non-zero, which
    is what fails the test.

    NOTE(review): judging by the test name, ``--save-model True`` is
    presumably what makes the script run its evaluation step -- confirm
    against the script itself.
    """
    command = "python cleanrl/dqn_atari.py --save-model True --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4"
    subprocess.run(command, shell=True, check=True)
12 changes: 10 additions & 2 deletions tests/test_atari_jax.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import subprocess


def test_dqn_jax():
def test_c51_jax():
subprocess.run(
"python cleanrl/dqn_atari_jax.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4",
"python cleanrl/c51_atari_jax.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4",
shell=True,
check=True,
)


def test_c51_jax_eval():
    """Smoke-test ``cleanrl/c51_atari_jax.py`` with ``--save-model True``.

    The script is launched through the shell with small step, buffer and
    batch values.  ``check=True`` makes ``subprocess.run`` raise
    ``subprocess.CalledProcessError`` when the script exits non-zero, which
    is what fails the test.

    NOTE(review): judging by the test name, ``--save-model True`` is
    presumably what makes the script run its evaluation step -- confirm
    against the script itself.
    """
    command = "python cleanrl/c51_atari_jax.py --save-model True --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4"
    subprocess.run(command, shell=True, check=True)
17 changes: 17 additions & 0 deletions tests/test_atari_jax_gymnasium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import subprocess


def test_dqn_jax():
    """Smoke-test a short training run of ``cleanrl/dqn_atari_jax.py``.

    The script is launched through the shell with small step, buffer and
    batch values.  ``check=True`` makes ``subprocess.run`` raise
    ``subprocess.CalledProcessError`` when the script exits non-zero, which
    is what fails the test.
    """
    command = "python cleanrl/dqn_atari_jax.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4"
    subprocess.run(command, shell=True, check=True)


def test_dqn_jax_eval():
    """Smoke-test ``cleanrl/dqn_atari_jax.py`` with ``--save-model True``.

    The script is launched through the shell with small step, buffer and
    batch values.  ``check=True`` makes ``subprocess.run`` raise
    ``subprocess.CalledProcessError`` when the script exits non-zero, which
    is what fails the test.

    NOTE(review): judging by the test name, ``--save-model True`` is
    presumably what makes the script run its evaluation step -- confirm
    against the script itself.
    """
    command = "python cleanrl/dqn_atari_jax.py --save-model True --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4"
    subprocess.run(command, shell=True, check=True)
8 changes: 8 additions & 0 deletions tests/test_envpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@ def test_ppo_atari_envpool_xla_jax_scan():
shell=True,
check=True,
)


def test_ppo_atari_envpool_xla_jax_scan_eval():
    """Smoke-test ``cleanrl/ppo_atari_envpool_xla_jax_scan.py`` with ``--save-model True``.

    The script is launched through the shell with small env, step, epoch,
    minibatch and total-timestep values.  ``check=True`` makes
    ``subprocess.run`` raise ``subprocess.CalledProcessError`` when the
    script exits non-zero, which is what fails the test.

    NOTE(review): judging by the test name, ``--save-model True`` is
    presumably what makes the script run its evaluation step -- confirm
    against the script itself.
    """
    command = "python cleanrl/ppo_atari_envpool_xla_jax_scan.py --save-model True --num-envs 8 --num-steps 6 --update-epochs 1 --num-minibatches 1 --total-timesteps 256"
    subprocess.run(command, shell=True, check=True)