This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit cf43523

Format .jl files (#91)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent c2f09f1 commit cf43523


12 files changed: +63 -63 lines changed


src/algorithms/cfr/external_sampling_mccfr.jl
Lines changed: 4 additions & 3 deletions

@@ -20,12 +20,13 @@ end
 
 (p::ExternalSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)
 
-RLBase.get_prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_policy, env)
+RLBase.get_prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) =
+    get_prob(p.behavior_policy, env)
 
 function ExternalSamplingMCCFRPolicy(;
     env::AbstractEnv,
     n_iter::Int,
-    rng=Random.GLOBAL_RNG,
+    rng = Random.GLOBAL_RNG,
 )
     @assert NumAgentStyle(env) isa MultiAgent
     @assert DynamicStyle(env) === SEQUENTIAL
@@ -90,4 +91,4 @@ function external_sampling(env, i, nodes, rng)
             u
         end
     end
-end
+end
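
Note: the commit only reformats the Julia sources, so every hunk is a whitespace-only change. As a minimal sketch (hypothetical helpers `f` and `g`, not from the commit), the two rules visible above (spaces around `=` in keyword defaults, and continuing a definition on the next line after a trailing `=`) leave the parsed code unchanged:

    # Hypothetical sketch: both definitions behave identically;
    # a line that ends in `=` simply continues onto the next line.
    using Random

    f(; rng=Random.GLOBAL_RNG) = rand(rng)
    g(; rng = Random.GLOBAL_RNG) =
        rand(rng)

    @assert f(rng = MersenneTwister(0)) == g(rng = MersenneTwister(0))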

src/algorithms/cfr/outcome_sampling_mccfr.jl
Lines changed: 6 additions & 5 deletions

@@ -20,13 +20,14 @@ end
 
 (p::OutcomeSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)
 
-RLBase.get_prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_policy, env)
+RLBase.get_prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) =
+    get_prob(p.behavior_policy, env)
 
 function OutcomeSamplingMCCFRPolicy(;
     env::AbstractEnv,
     n_iter::Int,
-    rng=Random.GLOBAL_RNG,
-    ϵ=0.6
+    rng = Random.GLOBAL_RNG,
+    ϵ = 0.6,
 )
     @assert NumAgentStyle(env) isa MultiAgent
     @assert DynamicStyle(env) === SEQUENTIAL
@@ -91,9 +92,9 @@ function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
             w = u * π₋ᵢ
             rI .+= w * πₜₐᵢₗ .* ((1:n .== aᵢ) .- σ[aᵢ])
         else
-            sI .+= π₋ᵢ / s .* σ
+            sI .+= π₋ᵢ / s .* σ
         end
 
         u, πₜₐᵢₗ * σ[aᵢ]
     end
-end
+end
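
Note: besides the spacing changes, the formatter adds a trailing comma after the last keyword argument (`ϵ = 0.6,`). A minimal sketch with a hypothetical function `h` (not from the commit), showing that Julia accepts the trailing comma and the resulting method is unchanged:

    # Hypothetical sketch: a trailing comma after the last (keyword) argument
    # is legal syntax and does not change the method being defined.
    using Random

    h(;
        rng = Random.GLOBAL_RNG,
        ϵ = 0.6,
    ) = ϵ

    @assert h() == 0.6 && h(ϵ = 0.3) == 0.3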

src/algorithms/cfr/tabular_cfr.jl
Lines changed: 2 additions & 1 deletion

@@ -144,7 +144,8 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
     end
 end
 
-regret_matching!(node::InfoStateNode) = regret_matching!(node.strategy, node.cumulative_regret)
+regret_matching!(node::InfoStateNode) =
+    regret_matching!(node.strategy, node.cumulative_regret)
 
 function regret_matching!(strategy, cumulative_regret)
     s = mapreduce(x -> max(0, x), +, cumulative_regret)
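
The hunk above only rewraps the one-line `regret_matching!(node::InfoStateNode)` method, but it also shows the first line of the underlying `regret_matching!(strategy, cumulative_regret)`, which sums the positive part of the cumulative regret. For context, a standalone sketch of regret matching as used in CFR (an illustration, not the package's implementation): play each action in proportion to its positive cumulative regret, and fall back to a uniform strategy when no regret is positive.

    # Illustrative sketch of regret matching (not the code in tabular_cfr.jl).
    function regret_matching_sketch!(strategy, cumulative_regret)
        s = mapreduce(x -> max(0, x), +, cumulative_regret)  # sum of positive regrets
        if s > 0
            strategy .= max.(0, cumulative_regret) ./ s
        else
            strategy .= 1 / length(strategy)      # uniform fallback
        end
        strategy
    end

    σ = zeros(3)
    regret_matching_sketch!(σ, [2.0, -1.0, 1.0])  # σ == [2/3, 0.0, 1/3]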

src/algorithms/dqns/iqn.jl
Lines changed: 1 addition & 1 deletion

@@ -223,7 +223,7 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
         loss_per_quantile = reshape(sum(raw_loss; dims = 1), N, batch_size)
         loss_per_element = mean(loss_per_quantile; dims = 1) # use as priorities
         loss =
-            is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1//batch_size :
+            is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size :
             mean(loss_per_element)
         ignore() do
             # @assert all(loss_per_element .>= 0)
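
The only change here is the spacing around `//`. In Julia, `//` constructs a `Rational` and binds tighter than `*`, so `* 1//batch_size` and `* 1 // batch_size` both multiply by the rational `1//batch_size`, i.e. they divide the weighted loss by the batch size. A quick check with assumed values (not from the commit):

    # Assumed values; `//` has higher precedence than `*`, so the spacing is cosmetic.
    batch_size = 4
    weighted_loss = 6.0
    @assert weighted_loss * 1//batch_size ==
            weighted_loss * 1 // batch_size ==
            weighted_loss / batch_size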

src/algorithms/dqns/prioritized_dqn.jl
Lines changed: 1 addition & 1 deletion

@@ -146,7 +146,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
     gs = gradient(params(Q)) do
         q = Q(states)[actions]
         batch_losses = loss_func(G, q)
-        loss = dot(vec(weights), vec(batch_losses)) * 1//batch_size
+        loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
         ignore() do
             updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
             learner.loss = loss

src/algorithms/dqns/rainbow.jl
Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
         select_logits = logits[:, actions]
         batch_losses = loss_func(select_logits, target_distribution)
         loss =
-            is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1//batch_size :
+            is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size :
             mean(batch_losses)
         ignore() do
             if is_use_PER

src/algorithms/policy_gradient/A2C.jl
Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory)
         advantage = vec(gains) .- vec(values)
         actor_loss = -mean(log_probs_select .* Zygote.dropgrad(advantage))
         critic_loss = mean(advantage .^ 2)
-        entropy_loss = -sum(probs .* log_probs) * 1//size(probs, 2)
+        entropy_loss = -sum(probs .* log_probs) * 1 // size(probs, 2)
         loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
         ignore() do
             learner.actor_loss = actor_loss
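
Here `probs` holds action probabilities with (what appears to be) one column per sample, so the touched line computes the entropy of every column and averages it over the batch (the `* 1 // size(probs, 2)` factor); the result is then subtracted from the total loss, which encourages exploration. A small check with assumed values (not from the commit):

    # Assumed 2-action × 2-sample probability matrix (one column per sample).
    probs = [0.5 0.25; 0.5 0.75]
    log_probs = log.(probs)
    entropy_loss = -sum(probs .* log_probs) * 1 // size(probs, 2)
    per_sample = [-sum(probs[:, j] .* log_probs[:, j]) for j in 1:size(probs, 2)]
    @assert entropy_loss ≈ sum(per_sample) / 2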

src/algorithms/policy_gradient/A2CGAE.jl
Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ function RLBase.update!(learner::A2CGAELearner, t::AbstractTrajectory)
         advantage = vec(gains) .- vec(values)
         actor_loss = -mean(log_probs_select .* advantages)
         critic_loss = mean(advantage .^ 2)
-        entropy_loss = -sum(probs .* log_probs) * 1//size(probs, 2)
+        entropy_loss = -sum(probs .* log_probs) * 1 // size(probs, 2)
         loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
         ignore() do
             learner.actor_loss = actor_loss

src/algorithms/policy_gradient/policy_gradient.jl
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-include("vpg.jl")
+include("vpg.jl")
 include("A2C.jl")
 include("ppo.jl")
 include("A2CGAE.jl")

src/algorithms/policy_gradient/ppo.jl
Lines changed: 21 additions & 16 deletions

@@ -81,19 +81,21 @@ function PPOPolicy(;
     )
 end
 
-function RLBase.get_prob(p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}}, Normal}, state::AbstractArray)
-    p.approximator.actor(send_to_device(
-        device(p.approximator),
-        state,
-    )) |> send_to_host |> StructArray{Normal}
+function RLBase.get_prob(
+    p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}},Normal},
+    state::AbstractArray,
+)
+    p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+    send_to_host |>
+    StructArray{Normal}
 end
 
-function RLBase.get_prob(p::PPOPolicy{<:ActorCritic, Categorical}, state::AbstractArray)
-    logits = p.approximator.actor(send_to_device(
-        device(p.approximator),
-        state,
-    )) |> softmax |> send_to_host
-    [Categorical(x;check_args=false) for x in eachcol(logits)]
+function RLBase.get_prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray)
+    logits =
+        p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+        softmax |>
+        send_to_host
+    [Categorical(x; check_args = false) for x in eachcol(logits)]
 end
 
 RLBase.get_prob(p::PPOPolicy, env::MultiThreadEnv) = get_prob(p, get_state(env))
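
The Gaussian `get_prob` method above is rewrapped into a multi-line pipe chain. A line that ends in `|>` continues onto the next line, and `x |> f |> g` is simply `g(f(x))`, so the rewrap does not change the expression. A minimal sketch with hypothetical values (not from the commit):

    # Hypothetical sketch: breaking a pipe chain after `|>` keeps it one expression.
    x = [3, 1, 2]
    a = x |> sort |> reverse
    b = x |>
        sort |>
        reverse
    @assert a == b == reverse(sort(x)) == [3, 2, 1]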
@@ -164,14 +166,14 @@ function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
         if AC.actor isa NeuralNetworkApproximator{<:GaussianNetwork}
             μ, σ = AC.actor(s)
             log_p′ₐ = normlogpdf(μ, σ, a)
-            entropy_loss = mean((log(2.0f0π)+1)/2 .+ log.(σ))
+            entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
         else
             # actor is assumed to return discrete logits
             logit′ = AC.actor(s)
             p′ = softmax(logit′)
             log_p′ = logsoftmax(logit′)
             log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
-            entropy_loss = -sum(p′ .* log_p′) * 1//size(p′, 2)
+            entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
         end
 
         ratio = exp.(log_p′ₐ .- log_p)
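
The reformatted line in the Gaussian branch, `mean((log(2.0f0π) + 1) / 2 .+ log.(σ))`, averages the closed-form differential entropy of a normal distribution, (log(2π) + 1) / 2 + log(σ), over the batch. A quick check against Distributions.jl (which this file already uses for `Normal` and `Categorical`), with an assumed σ:

    # Assumed σ; entropy(Normal(μ, σ)) equals (log(2π) + 1) / 2 + log(σ).
    using Distributions

    σ = 0.3f0
    @assert (log(2.0f0π) + 1) / 2 + log(σ) ≈ entropy(Normal(0.0f0, σ))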
@@ -198,15 +200,18 @@ function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
     end
 end
 
-function (agent::Agent{<:Union{PPOPolicy, RandomStartPolicy{<:PPOPolicy}}})(::Training{PreActStage}, env::MultiThreadEnv)
+function (agent::Agent{<:Union{PPOPolicy,RandomStartPolicy{<:PPOPolicy}}})(
+    ::Training{PreActStage},
+    env::MultiThreadEnv,
+)
     state = get_state(env)
     dist = get_prob(agent.policy, env)
 
     # currently RandomPolicy returns a Matrix instead of a (vector of) distribution.
     if dist isa Matrix{<:Number}
-        dist = [Categorical(x;check_args=false) for x in eachcol(dist)]
+        dist = [Categorical(x; check_args = false) for x in eachcol(dist)]
     elseif dist isa Vector{<:Vector{<:Number}}
-        dist = [Categorical(x;check_args=false) for x in dist]
+        dist = [Categorical(x; check_args = false) for x in dist]
     end
 
     # !!! a little ugly
