
Commit 6409d3a

Format .jl files (#70)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

1 parent: e3f9375

7 files changed: +55 -39 lines

src/algorithms/dqns/common.jl

Lines changed: 2 additions & 9 deletions
@@ -91,15 +91,8 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)
     end
 end
 
-function (agent::Agent{<:QBasedPolicy{<:PERLearners}})(
-    ::RLCore.Training{PostActStage},
-    env,
-)
-    push!(
-        agent.trajectory;
-        reward = get_reward(env),
-        terminal = get_terminal(env),
-    )
+function (agent::Agent{<:QBasedPolicy{<:PERLearners}})(::RLCore.Training{PostActStage}, env)
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     if haskey(agent.trajectory, :priority)
         push!(agent.trajectory; priority = agent.policy.learner.default_priority)
     end

src/algorithms/dqns/dqn.jl

Lines changed: 13 additions & 8 deletions
@@ -88,13 +88,16 @@ end
 if `!isnothing(stack_size)`.
 """
 function (learner::DQNLearner)(env)
-    probs = env |>
-            get_state |>
-            x -> Flux.unsqueeze(x, ndims(x) + 1) |>
-            x -> send_to_device(device(learner.approximator), x) |>
-            learner.approximator |>
-            vec |>
-            send_to_host
+    probs =
+        env |>
+        get_state |>
+        x ->
+            Flux.unsqueeze(x, ndims(x) + 1) |>
+            x ->
+                send_to_device(device(learner.approximator), x) |>
+                learner.approximator |>
+                vec |>
+                send_to_host
 
     if ActionStyle(env) === FULL_ACTION_SET
         probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
@@ -130,7 +133,9 @@ function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
 
     target_q = Qₜ(next_states)
     if haskey(t, :next_legal_actions_mask)
-        target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, t[:next_legal_actions_mask]))
+        target_q .+=
+            typemin(eltype(target_q)) .*
+            (1 .- send_to_device(D, t[:next_legal_actions_mask]))
     end
 
     q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
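Note on the reformatted `probs` pipeline: the nested indentation the formatter introduces is not arbitrary. In Julia, the anonymous-function arrow `->` parses with lower precedence than `|>`, so a lambda placed in the middle of a pipe chain takes the remainder of the chain as its body, and the new indentation mirrors that parse. A minimal standalone sketch with made-up values (not part of the commit; the chain's result is the same under either reading, the point is only how it groups):

    # `->` binds more loosely than `|>`, so the lambda body below is
    # `x + 1 |> abs |> string`, not just `x + 1`.
    result =
        [1, -5] |>
        sum |>
        x ->
            x + 1 |>
            abs |>
            string

    @assert result == "3"   # sum gives -4, the lambda then yields string(abs(-3))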

src/algorithms/dqns/iqn.jl

Lines changed: 5 additions & 2 deletions
@@ -187,7 +187,9 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
     avg_zₜ = mean(zₜ, dims = 2)
 
     if !isnothing(batch.next_legal_actions_mask)
-        avg_zₜ .+= typemin(eltype(avg_zₜ)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
+        avg_zₜ .+=
+            typemin(eltype(avg_zₜ)) .*
+            (1 .- send_to_device(D, batch.next_legal_actions_mask))
     end
 
     aₜ = argmax(avg_zₜ, dims = 1)
@@ -224,7 +226,8 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
         huber_loss ./ κ
     loss_per_quantile = reshape(sum(raw_loss; dims = 1), N, batch_size)
     loss_per_element = mean(loss_per_quantile; dims = 1) # use as priorities
-    loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size : mean(loss_per_element)
+    loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size :
+           mean(loss_per_element)
     ignore() do
         # @assert all(loss_per_element .>= 0)
         is_use_PER && (
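Note on the reformatted ternary: with prioritized experience replay enabled, the loss is the importance-sampling-weighted average of the per-element losses (a dot product divided by the batch size); otherwise it falls back to a plain mean. A hedged standalone sketch with hypothetical numbers (not the library code):

    using LinearAlgebra: dot
    using Statistics: mean

    loss_per_element = Float32[0.5, 1.0, 0.25, 0.25]   # hypothetical per-sample losses
    weights          = Float32[1.0, 0.5, 1.0, 1.0]     # hypothetical importance-sampling weights
    batch_size       = length(loss_per_element)

    per_loss   = dot(vec(weights), vec(loss_per_element)) * 1 // batch_size
    plain_loss = mean(loss_per_element)

    @assert per_loss ≈ 0.375f0   # (0.5 + 0.5 + 0.25 + 0.25) / 4
    @assert plain_loss ≈ 0.5f0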

src/algorithms/dqns/prioritized_dqn.jl

Lines changed: 13 additions & 8 deletions
@@ -102,13 +102,16 @@ end
 if `!isnothing(stack_size)`.
 """
 function (learner::PrioritizedDQNLearner)(env)
-    probs = env |>
-            get_state |>
-            x -> Flux.unsqueeze(x, ndims(x) + 1) |>
-            x -> send_to_device(device(learner.approximator), x) |>
-            learner.approximator |>
-            vec |>
-            send_to_host
+    probs =
+        env |>
+        get_state |>
+        x ->
+            Flux.unsqueeze(x, ndims(x) + 1) |>
+            x ->
+                send_to_device(device(learner.approximator), x) |>
+                learner.approximator |>
+                vec |>
+                send_to_host
 
     if ActionStyle(env) === FULL_ACTION_SET
         probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
@@ -138,7 +141,9 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
 
     target_q = Qₜ(next_states)
    if !isnothing(batch.next_legal_actions_mask)
-        target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
+        target_q .+=
+            typemin(eltype(target_q)) .*
+            (1 .- send_to_device(D, batch.next_legal_actions_mask))
     end
 
     q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
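The masking hunks in this file (and in dqn.jl, iqn.jl, and rainbow.jl) all reformat the same idea: push the values of illegal actions toward the element type's minimum so the following maximum/argmax never selects them. A simplified sketch of that idea, using `ifelse` instead of the library's broadcasted add and hypothetical values (an illustrative variant, not the commit's code):

    q    = Float32[0.3, 1.2, -0.5, 0.9]       # hypothetical Q values for 4 actions
    mask = Bool[true, false, true, false]     # legal-actions mask: only actions 1 and 3 are legal

    masked_q = ifelse.(mask, q, typemin(Float32))   # illegal entries become -Inf32

    @assert argmax(masked_q) == 1     # 1.2 (action 2) is illegal, so action 1 wins
    @assert maximum(masked_q) ≈ 0.3f0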

src/algorithms/dqns/rainbow.jl

Lines changed: 5 additions & 2 deletions
@@ -161,7 +161,9 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
     next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :)
     next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :)
     if !isnothing(batch.next_legal_actions_mask)
-        next_q .+= typemin(eltype(next_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
+        next_q .+=
+            typemin(eltype(next_q)) .*
+            (1 .- send_to_device(D, batch.next_legal_actions_mask))
     end
     next_prob_select = select_best_probs(next_probs, next_q)
 
@@ -186,7 +188,8 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
     logits = reshape(Q(states), n_atoms, n_actions, :)
     select_logits = logits[:, actions]
     batch_losses = loss_func(select_logits, target_distribution)
-    loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size : mean(batch_losses)
+    loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size :
+           mean(batch_losses)
     ignore() do
         if is_use_PER
             updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))

src/algorithms/policy_gradient/A2C.jl

Lines changed: 8 additions & 5 deletions
@@ -28,10 +28,11 @@ Base.@kwdef mutable struct A2CLearner{A<:ActorCritic} <: AbstractLearner
 end
 
 function (learner::A2CLearner)(env::MultiThreadEnv)
-    logits = learner.approximator.actor(send_to_device(
-        device(learner.approximator),
-        get_state(env),
-    )) |> send_to_host
+    logits =
+        learner.approximator.actor(send_to_device(
+            device(learner.approximator),
+            get_state(env),
+        )) |> send_to_host
 
     if ActionStyle(env[1]) === FULL_ACTION_SET
         logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env))
@@ -87,7 +88,9 @@ function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory)
     gs = gradient(ps) do
         logits = AC.actor(states_flattened)
         if haskey(t, :legal_actions_mask)
-            logits .+= typemin(eltype(logits)) .* (1 .- flatten_batch(send_to_device(D, t[:legal_actions_mask])))
+            logits .+=
+                typemin(eltype(logits)) .*
+                (1 .- flatten_batch(send_to_device(D, t[:legal_actions_mask])))
         end
         probs = softmax(logits)
         log_probs = logsoftmax(logits)

src/algorithms/policy_gradient/ppo.jl

Lines changed: 9 additions & 5 deletions
@@ -75,10 +75,11 @@ function PPOLearner(;
 end
 
 function (learner::PPOLearner)(env::MultiThreadEnv)
-    logits = learner.approximator.actor(send_to_device(
-        device(learner.approximator),
-        get_state(env),
-    )) |> send_to_host
+    logits =
+        learner.approximator.actor(send_to_device(
+            device(learner.approximator),
+            get_state(env),
+        )) |> send_to_host
 
     if ActionStyle(env[1]) === FULL_ACTION_SET
         logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env))
@@ -139,7 +140,10 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
         inds = rand_inds[(i-1)*microbatch_size+1:i*microbatch_size]
         s = send_to_device(D, select_last_dim(states_flatten, inds))
         if haskey(t, :legal_actions_mask)
-            lam = send_to_device(D, select_last_dim(flatten_batch(t[:legal_actions_mask]), inds))
+            lam = send_to_device(
+                D,
+                select_last_dim(flatten_batch(t[:legal_actions_mask]), inds),
+            )
         end
         a = vec(actions)[inds]
         r = send_to_device(D, vec(returns)[inds])
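For context, the second hunk sits inside PPO's microbatch loop: shuffled sample indices are consumed in fixed-size chunks of `microbatch_size`, and each chunk selects the matching states, actions, and returns for one gradient step. A small self-contained sketch of just that index-slicing pattern, with hypothetical sizes (not the library code):

    using Random: shuffle

    n_samples       = 12
    microbatch_size = 4
    rand_inds       = shuffle(collect(1:n_samples))

    for i in 1:(n_samples ÷ microbatch_size)
        inds = rand_inds[(i-1)*microbatch_size+1:i*microbatch_size]
        @assert length(inds) == microbatch_size
        # ... slice the trajectory with `inds` and take one gradient step here
    end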
