Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 39159f4

Browse files
Format .jl files (#82)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 87a63f0 commit 39159f4

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

src/algorithms/policy_gradient/ddpg.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ function DDPGPolicy(;
8888
act_noise,
8989
step,
9090
rng,
91-
0.f0,0.f0,
91+
0.f0,
92+
0.f0,
9293
)
9394
end
9495

@@ -138,7 +139,7 @@ function RLBase.update!(p::DDPGPolicy, traj::CircularCompactSARTSATrajectory)
138139
gs1 = gradient(Flux.params(C)) do
139140
q = C(vcat(s, a)) |> vec
140141
loss = mean((y .- q) .^ 2)
141-
ignore() do
142+
ignore() do
142143
p.critic_loss = loss
143144
end
144145
loss
@@ -148,7 +149,7 @@ function RLBase.update!(p::DDPGPolicy, traj::CircularCompactSARTSATrajectory)
148149

149150
gs2 = gradient(Flux.params(A)) do
150151
loss = -mean(C(vcat(s, A(s))))
151-
ignore() do
152+
ignore() do
152153
p.actor_loss = loss
153154
end
154155
loss

src/experiments/rl_envs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ function RLCore.Experiment(
511511
N_ENV = 16
512512
UPDATE_FREQ = 10
513513
env = MultiThreadEnv([
514-
CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i = 1:N_ENV
514+
CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i in 1:N_ENV
515515
])
516516
ns, na = length(get_state(env[1])), length(get_actions(env[1]))
517517
RLBase.reset!(env, is_force = true)
@@ -599,7 +599,7 @@ function RLCore.Experiment(
599599
N_ENV = 16
600600
UPDATE_FREQ = 10
601601
env = MultiThreadEnv([
602-
CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i = 1:N_ENV
602+
CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i in 1:N_ENV
603603
])
604604
ns, na = length(get_state(env[1])), length(get_actions(env[1]))
605605
RLBase.reset!(env, is_force = true)
@@ -800,7 +800,7 @@ function RLCore.Experiment(
800800
N_ENV = 8
801801
UPDATE_FREQ = 16
802802
env = MultiThreadEnv([
803-
CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i = 1:N_ENV
803+
CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i in 1:N_ENV
804804
])
805805
ns, na = length(get_state(env[1])), length(get_actions(env[1]))
806806
RLBase.reset!(env, is_force = true)

test/runtests.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using Random
1717
Val(method),
1818
Val(:CartPole),
1919
nothing;
20-
save_dir=joinpath(dir, "CartPole", string(method)),
20+
save_dir = joinpath(dir, "CartPole", string(method)),
2121
))
2222
@info "stats for $method" avg_reward = mean(res.hook[1].rewards) avg_fps =
2323
1 / mean(res.hook[2].times)
@@ -29,22 +29,32 @@ using Random
2929
Val(method),
3030
Val(:MountainCar),
3131
nothing;
32-
save_dir=joinpath(dir, "MountainCar", string(method)),
32+
save_dir = joinpath(dir, "MountainCar", string(method)),
3333
))
3434
@info "stats for $method" avg_reward = mean(res.hook[1].rewards) avg_fps =
3535
1 / mean(res.hook[2].times)
3636
end
3737

3838
for method in (:A2C, :A2CGAE, :PPO)
39-
res = run(Experiment(Val(:JuliaRL), Val(method), Val(:CartPole), nothing;
40-
save_dir=joinpath(dir, "CartPole", string(method)),))
39+
res = run(Experiment(
40+
Val(:JuliaRL),
41+
Val(method),
42+
Val(:CartPole),
43+
nothing;
44+
save_dir = joinpath(dir, "CartPole", string(method)),
45+
))
4146
@info "stats for $method" avg_reward =
4247
mean(Iterators.flatten(res.hook[1].rewards))
4348
end
4449

4550
for method in (:DDPG, :SAC)
46-
res = run(Experiment(Val(:JuliaRL), Val(method), Val(:Pendulum), nothing;
47-
save_dir=joinpath(dir, "Pendulum", string(method)),))
51+
res = run(Experiment(
52+
Val(:JuliaRL),
53+
Val(method),
54+
Val(:Pendulum),
55+
nothing;
56+
save_dir = joinpath(dir, "Pendulum", string(method)),
57+
))
4858
@info "stats for $method" avg_reward =
4959
mean(Iterators.flatten(res.hook[1].rewards))
5060
end

0 commit comments

Comments
 (0)