From 93d2c8bab99c689e11414416f979bf58eb3f6231 Mon Sep 17 00:00:00 2001
From: pseudo-rnd-thoughts
Date: Wed, 15 Nov 2023 14:08:03 +0000
Subject: [PATCH] Use the training `end_e` as the evaluation epsilon: `evaluate(..., epsilon=args.end_e)`

---
 cleanrl/c51_atari.py                       | 2 +-
 cleanrl/c51_atari_jax.py                   | 2 +-
 cleanrl/dqn_atari.py                       | 2 +-
 cleanrl/dqn_atari_jax.py                   | 2 +-
 cleanrl/qdagger_dqn_atari_impalacnn.py     | 6 +++---
 cleanrl/qdagger_dqn_atari_jax_impalacnn.py | 6 +++---
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/cleanrl/c51_atari.py b/cleanrl/c51_atari.py
index 8e47bacc5..074af8dbf 100755
--- a/cleanrl/c51_atari.py
+++ b/cleanrl/c51_atari.py
@@ -303,7 +303,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
             run_name=f"{run_name}-eval",
             Model=QNetwork,
             device=device,
-            epsilon=0.05,
+            epsilon=args.end_e,
         )
         for idx, episodic_return in enumerate(episodic_returns):
             writer.add_scalar("eval/episodic_return", episodic_return, idx)
diff --git a/cleanrl/c51_atari_jax.py b/cleanrl/c51_atari_jax.py
index 93c436ec5..f2c544d37 100644
--- a/cleanrl/c51_atari_jax.py
+++ b/cleanrl/c51_atari_jax.py
@@ -343,7 +343,7 @@ def get_action(q_state, obs):
             eval_episodes=10,
             run_name=f"{run_name}-eval",
             Model=QNetwork,
-            epsilon=0.05,
+            epsilon=args.end_e,
         )
         for idx, episodic_return in enumerate(episodic_returns):
             writer.add_scalar("eval/episodic_return", episodic_return, idx)
diff --git a/cleanrl/dqn_atari.py b/cleanrl/dqn_atari.py
index a4c3df339..3b7f63c78 100644
--- a/cleanrl/dqn_atari.py
+++ b/cleanrl/dqn_atari.py
@@ -272,7 +272,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
             run_name=f"{run_name}-eval",
             Model=QNetwork,
             device=device,
-            epsilon=0.05,
+            epsilon=args.end_e,
         )
         for idx, episodic_return in enumerate(episodic_returns):
             writer.add_scalar("eval/episodic_return", episodic_return, idx)
diff --git a/cleanrl/dqn_atari_jax.py b/cleanrl/dqn_atari_jax.py
index 5f74d57a9..a850afd21 100644
--- a/cleanrl/dqn_atari_jax.py
+++ b/cleanrl/dqn_atari_jax.py
@@ -301,7 +301,7 @@ def mse_loss(params):
             eval_episodes=10,
             run_name=f"{run_name}-eval",
             Model=QNetwork,
-            epsilon=0.05,
+            epsilon=args.end_e,
         )
         for idx, episodic_return in enumerate(episodic_returns):
             writer.add_scalar("eval/episodic_return", episodic_return, idx)
diff --git a/cleanrl/qdagger_dqn_atari_impalacnn.py b/cleanrl/qdagger_dqn_atari_impalacnn.py
index ef7922a91..92247aa82 100644
--- a/cleanrl/qdagger_dqn_atari_impalacnn.py
+++ b/cleanrl/qdagger_dqn_atari_impalacnn.py
@@ -265,7 +265,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
         eval_episodes=args.teacher_eval_episodes,
         run_name=f"{run_name}-teacher-eval",
         Model=TeacherModel,
-        epsilon=0.05,
+        epsilon=args.end_e,
         capture_video=False,
     )
     writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0)
@@ -344,7 +344,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
                 run_name=f"{run_name}-eval",
                 Model=QNetwork,
                 device=device,
-                epsilon=0.05,
+                epsilon=args.end_e,
             )
             print(episodic_returns)
             writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step)
@@ -461,7 +461,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
             run_name=f"{run_name}-eval",
             Model=QNetwork,
             device=device,
-            epsilon=0.05,
+            epsilon=args.end_e,
         )
         for idx, episodic_return in enumerate(episodic_returns):
             writer.add_scalar("eval/episodic_return", episodic_return, idx)
diff --git a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py
index ce55baf4c..56e9c764f 100644
--- a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py
+++ b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py
@@ -264,7 +264,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         eval_episodes=args.teacher_eval_episodes,
         run_name=f"{run_name}-teacher-eval",
         Model=TeacherModel,
-        epsilon=0.05,
+        epsilon=args.end_e,
         capture_video=False,
     )
     writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0)
@@ -363,7 +363,7 @@ def loss(params, td_target, teacher_q_values, distill_coeff):
                 eval_episodes=10,
                 run_name=f"{run_name}-eval",
                 Model=QNetwork,
-                epsilon=0.05,
+                epsilon=args.end_e,
             )
             print(episodic_returns)
             writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step)
@@ -471,7 +471,7 @@ def loss(params, td_target, teacher_q_values, distill_coeff):
             eval_episodes=10,
             run_name=f"{run_name}-eval",
             Model=QNetwork,
-            epsilon=0.05,
+            epsilon=args.end_e,
         )
         for idx, episodic_return in enumerate(episodic_returns):
             writer.add_scalar("eval/episodic_return", episodic_return, idx)
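
Note for reviewers: every hunk above replaces a hardcoded `epsilon=0.05` with `args.end_e`, the terminal value of the `linear_schedule` named in the hunk headers. The sketch below is not taken from this patch; its body is an assumption reconstructed from the schedule's signature, and the `end_e=0.01` default it mentions is likewise assumed. It illustrates why the policy's own terminal exploration rate, rather than an arbitrary constant, is the consistent epsilon to evaluate with.

# A minimal sketch of the annealing schedule, assuming a body consistent
# with the signature `linear_schedule(start_e, end_e, duration, t)` shown
# in the hunk headers; the actual implementation may differ.
def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    # Interpolate linearly from start_e down to end_e over `duration` steps,
    # then hold epsilon at end_e for the remainder of training.
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

# Once t >= duration, the behaviour policy acts with epsilon == end_e, so
# evaluate(..., epsilon=args.end_e) reproduces the end-of-training policy.
# Under the assumed default end_e=0.01, the old hardcoded 0.05 evaluated
# the agent with five times the random-action probability it trained down to.
assert linear_schedule(1.0, 0.01, 1_000_000, 2_000_000) == 0.01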