Skip to content

Commit d0d6bae

Browse files
authored
fix pre-commit (#351)
1 parent d43e70d commit d0d6bae

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

cleanrl/dqn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
214214
if global_step % args.target_network_frequency == 0:
215215
for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
216216
target_network_param.data.copy_(
217-
args.tau * q_network_param.data + (1. - args.tau) * target_network_param.data)
217+
args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
218+
)
218219

219220
if args.save_model:
220221
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"

cleanrl/dqn_atari.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
236236
if global_step % args.target_network_frequency == 0:
237237
for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
238238
target_network_param.data.copy_(
239-
args.tau * q_network_param.data + (1. - args.tau) * target_network_param.data)
239+
args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
240+
)
240241

241242
if args.save_model:
242243
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"

cleanrl/dqn_atari_jax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ def mse_loss(params):
264264

265265
# update target network
266266
if global_step % args.target_network_frequency == 0:
267-
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau))
267+
q_state = q_state.replace(
268+
target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau)
269+
)
268270

269271
if args.save_model:
270272
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"

cleanrl/dqn_jax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,9 @@ def mse_loss(params):
236236

237237
# update target network
238238
if global_step % args.target_network_frequency == 0:
239-
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau))
239+
q_state = q_state.replace(
240+
target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau)
241+
)
240242

241243
if args.save_model:
242244
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"

0 commit comments

Comments
 (0)