This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 7829fa5

Format .jl files (#88)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent: 44e3358 · commit: 7829fa5

11 files changed: +46 −39 lines

src/algorithms/cfr/tabular_cfr.jl

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
     v = 0.0
     node = nodes[get_state(env)]
     legal_actions = get_legal_actions(env)
-    U = player == get_current_player(env) ?
+    U =
+        player == get_current_player(env) ?
         Vector{Float64}(undef, length(legal_actions)) : nothing
 
     for (i, action) in enumerate(legal_actions)

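The hunk above only re-wraps the line: the right-hand side of the assignment moves onto its own lines. A minimal sketch (illustrative names, not the repository's code) showing that Julia keeps parsing when an expression is still incomplete at a line break, so the wrapped form is equivalent to the single-line ternary:

x = 3
y =
    x > 2 ?
    "big" : "small"
println(y)  # prints "big", same as the one-line version
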
src/algorithms/dqns/dqn.jl

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ function DQNLearner(;
         target_update_freq,
         update_step,
         rng,
-        0.f0,
+        0.0f0,
     )
 end
 

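The only change here is the spelling of the Float32 literal. A quick standalone check (not part of the commit) that `0.0f0` is the Float32 zero, so the rewrite is purely cosmetic:

x = 0.0f0
println(typeof(x))  # Float32
println(x == 0)     # true
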
src/algorithms/dqns/iqn.jl

Lines changed: 5 additions & 4 deletions
@@ -120,7 +120,7 @@ function IQNLearner(;
     β_priority = 0.5f0,
     rng = Random.GLOBAL_RNG,
     device_rng = CUDA.CURAND.RNG(),
-    loss = 0.f0,
+    loss = 0.0f0,
 )
     copyto!(approximator, target_approximator) # force sync
     if device(approximator) !== device(device_rng)
@@ -200,7 +200,7 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
     is_use_PER = !isnothing(batch.priorities) # is use Prioritized Experience Replay
     if is_use_PER
         updated_priorities = Vector{Float32}(undef, batch_size)
-        weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
+        weights = 1.0f0 ./ ((batch.priorities .+ 1f-10) .^ β)
         weights ./= maximum(weights)
         weights = send_to_device(D, weights)
     end
@@ -222,8 +222,9 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
             huber_loss ./ κ
         loss_per_quantile = reshape(sum(raw_loss; dims = 1), N, batch_size)
         loss_per_element = mean(loss_per_quantile; dims = 1) # use as priorities
-        loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size :
-            mean(loss_per_element)
+        loss =
+            is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1//batch_size :
+            mean(loss_per_element)
         ignore() do
             # @assert all(loss_per_element .>= 0)
             is_use_PER && (

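Besides the literal spelling, the formatter drops the spaces around `//`, which reads naturally because the rational operator binds tighter than `*`. A sketch with made-up values (`batch_size` and `total_loss` here are illustrative, not the learner's fields):

batch_size = 4
total_loss = 10.0f0
println(1//batch_size)               # 1//4, a Rational{Int}
println(total_loss * 1//batch_size)  # 2.5, i.e. total_loss * (1//batch_size)
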
src/algorithms/dqns/prioritized_dqn.jl

Lines changed: 4 additions & 4 deletions
@@ -63,7 +63,7 @@ function PrioritizedDQNLearner(;
     update_freq::Int = 1,
     target_update_freq::Int = 100,
     update_step::Int = 0,
-    default_priority::Float32 = 100f0,
+    default_priority::Float32 = 100.0f0,
     β_priority::Float32 = 0.5f0,
     rng = Random.GLOBAL_RNG,
 ) where {Tq,Tt,Tf}
@@ -83,7 +83,7 @@ function PrioritizedDQNLearner(;
         default_priority,
         β_priority,
         rng,
-        0.f0,
+        0.0f0,
     )
 end
 
@@ -129,7 +129,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
     actions = CartesianIndex.(batch.actions, 1:batch_size)
 
     updated_priorities = Vector{Float32}(undef, batch_size)
-    weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
+    weights = 1.0f0 ./ ((batch.priorities .+ 1f-10) .^ β)
     weights ./= maximum(weights)
     weights = send_to_device(D, weights)
 
@@ -146,7 +146,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
     gs = gradient(params(Q)) do
         q = Q(states)[actions]
         batch_losses = loss_func(G, q)
-        loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
+        loss = dot(vec(weights), vec(batch_losses)) * 1//batch_size
         ignore() do
             updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
             learner.loss = loss

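For context on the lines being touched: the weights are the usual prioritized-replay importance weights, inverted priorities raised to β and normalized by their maximum. A sketch with illustrative numbers only (β and the priorities are made up):

β = 0.5f0
priorities = Float32[0.1, 1.0, 4.0]
weights = 1.0f0 ./ ((priorities .+ 1f-10) .^ β)
weights ./= maximum(weights)
println(weights)  # ≈ Float32[1.0, 0.316, 0.158]
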
src/algorithms/dqns/rainbow.jl

Lines changed: 5 additions & 4 deletions
@@ -117,7 +117,7 @@ function RainbowLearner(;
         default_priority,
         β_priority,
         rng,
-        0.f0,
+        0.0f0,
     )
 end
 
@@ -175,7 +175,7 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
     is_use_PER = !isnothing(batch.priorities) # is use Prioritized Experience Replay
     if is_use_PER
         updated_priorities = Vector{Float32}(undef, batch_size)
-        weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
+        weights = 1.0f0 ./ ((batch.priorities .+ 1f-10) .^ β)
         weights ./= maximum(weights)
         weights = send_to_device(D, weights)
     end
@@ -184,8 +184,9 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
         logits = reshape(Q(states), n_atoms, n_actions, :)
         select_logits = logits[:, actions]
         batch_losses = loss_func(select_logits, target_distribution)
-        loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size :
-            mean(batch_losses)
+        loss =
+            is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1//batch_size :
+            mean(batch_losses)
         ignore() do
             if is_use_PER
                 updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))

src/algorithms/policy_gradient/A2C.jl

Lines changed: 6 additions & 6 deletions
@@ -17,14 +17,14 @@ Base.@kwdef mutable struct A2CLearner{A<:ActorCritic} <: AbstractLearner
     approximator::A
     γ::Float32
     max_grad_norm::Union{Nothing,Float32} = nothing
-    norm::Float32 = 0.f0
+    norm::Float32 = 0.0f0
     actor_loss_weight::Float32
     critic_loss_weight::Float32
     entropy_loss_weight::Float32
-    actor_loss::Float32 = 0.f0
-    critic_loss::Float32 = 0.f0
-    entropy_loss::Float32 = 0.f0
-    loss::Float32 = 0.f0
+    actor_loss::Float32 = 0.0f0
+    critic_loss::Float32 = 0.0f0
+    entropy_loss::Float32 = 0.0f0
+    loss::Float32 = 0.0f0
 end
 
 function (learner::A2CLearner)(env::MultiThreadEnv)
@@ -83,7 +83,7 @@ function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory)
         advantage = vec(gains) .- vec(values)
         actor_loss = -mean(log_probs_select .* Zygote.dropgrad(advantage))
         critic_loss = mean(advantage .^ 2)
-        entropy_loss = -sum(probs .* log_probs) * 1 // size(probs, 2)
+        entropy_loss = -sum(probs .* log_probs) * 1//size(probs, 2)
         loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
         ignore() do
             learner.actor_loss = actor_loss

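The reformatted entropy line keeps the same computation: sum the elementwise p · log p over the whole matrix and divide by the number of columns (one distribution per sample). A sketch with a toy probability matrix (the values are illustrative only):

probs = Float32[0.5 0.25; 0.5 0.75]   # 2 actions × 2 samples, one distribution per column
log_probs = log.(probs)
entropy = -sum(probs .* log_probs) * 1//size(probs, 2)
println(entropy)  # ≈ 0.628, the mean entropy in nats
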
src/algorithms/policy_gradient/A2CGAE.jl

Lines changed: 6 additions & 6 deletions
@@ -17,14 +17,14 @@ Base.@kwdef mutable struct A2CGAELearner{A<:ActorCritic} <: AbstractLearner
     γ::Float32
     λ::Float32
     max_grad_norm::Union{Nothing,Float32} = nothing
-    norm::Float32 = 0.f0
+    norm::Float32 = 0.0f0
     actor_loss_weight::Float32
     critic_loss_weight::Float32
    entropy_loss_weight::Float32
-    actor_loss::Float32 = 0.f0
-    critic_loss::Float32 = 0.f0
-    entropy_loss::Float32 = 0.f0
-    loss::Float32 = 0.f0
+    actor_loss::Float32 = 0.0f0
+    critic_loss::Float32 = 0.0f0
+    entropy_loss::Float32 = 0.0f0
+    loss::Float32 = 0.0f0
 end
 
 (learner::A2CGAELearner)(env::MultiThreadEnv) =
@@ -88,7 +88,7 @@ function RLBase.update!(learner::A2CGAELearner, t::AbstractTrajectory)
         advantage = vec(gains) .- vec(values)
        actor_loss = -mean(log_probs_select .* advantages)
        critic_loss = mean(advantage .^ 2)
-        entropy_loss = -sum(probs .* log_probs) * 1 // size(probs, 2)
+        entropy_loss = -sum(probs .* log_probs) * 1//size(probs, 2)
        loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
        ignore() do
            learner.actor_loss = actor_loss

src/algorithms/policy_gradient/ddpg.jl

Lines changed: 2 additions & 2 deletions
@@ -88,8 +88,8 @@ function DDPGPolicy(;
         act_noise,
         step,
         rng,
-        0.f0,
-        0.f0,
+        0.0f0,
+        0.0f0,
     )
 end
 

src/algorithms/policy_gradient/ppo.jl

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
 
         actor_loss = -mean(min.(surr1, surr2))
         critic_loss = mean((r .- v′) .^ 2)
-        entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
+        entropy_loss = -sum(p′ .* log_p′) * 1//size(p′, 2)
         loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
 
         ignore() do

src/algorithms/policy_gradient/sac.jl

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ function evaluate(p::SACPolicy, state)
     π_dist = Normal.(μ, exp.(log_σ))
     z = rand.(p.rng, π_dist)
     logp_π = sum(logpdf.(π_dist, z), dims = 1)
-    logp_π -= sum((2f0 .* (log(2f0) .- z - softplus.(-2f0 * z))), dims = 1)
+    logp_π -= sum((2.0f0 .* (log(2.0f0) .- z - softplus.(-2.0f0 * z))), dims = 1)
     return tanh.(z), logp_π
 end
 

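The reformatted sac.jl line is the numerically stable tanh correction, which relies on the identity log(1 - tanh(z)^2) = 2(log 2 - z - softplus(-2z)). A standalone check of that identity (the local `softplus` is a stand-in with the usual definition, and z = 3 is an arbitrary test value, neither taken from the commit):

softplus(x) = log1p(exp(x))  # stand-in, same definition as the usual softplus
z = 3.0f0
lhs = log(1 - tanh(z)^2)
rhs = 2.0f0 * (log(2.0f0) - z - softplus(-2.0f0 * z))
println(lhs ≈ rhs)  # true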