This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 74c6e5a

fix legal_actions_mask errors (#74)
1 parent bf79285 commit 74c6e5a

File tree: 6 files changed, +40 −90 lines
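Every learner touched here previously masked illegal actions by adding `typemin(eltype(x)) .* (1 .- legal_actions_mask)` to its Q-values or logits. The commit deletes that expression from the learners' forward passes and, in the Q-target code, replaces it with an explicitly built additive mask. A toy sketch of the arithmetic, not code from the repository: `typemin(Float32)` is `-Inf32`, and `-Inf32 * 0` is `NaN`, so the old expression corrupts the entries of the legal actions whenever the values are floating point, which is presumably the source of the `legal_actions_mask` errors.

```julia
# Toy illustration (not from the repository) of the arithmetic being replaced.
q = Float32[1.0, 2.0, 3.0]
legal_actions_mask = [true, false, true]

# Old pattern: -Inf32 * 0 is NaN, so the entries of *legal* actions become NaN.
old = q .+ typemin(eltype(q)) .* (1 .- legal_actions_mask)
# old contains NaN for the two legal actions and -Inf32 for the illegal one.

# New pattern used in this commit: build the additive mask explicitly,
# so legal entries stay untouched and illegal entries become -Inf32.
masked_value = fill(typemin(Float32), size(legal_actions_mask))
masked_value[legal_actions_mask] .= 0
new = q .+ masked_value
# new is [1.0f0, -Inf32, 3.0f0].
```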


src/algorithms/dqns/dqn.jl

Lines changed: 11 additions & 17 deletions
@@ -88,21 +88,15 @@ end
 if `!isnothing(stack_size)`.
 """
 function (learner::DQNLearner)(env)
-    probs =
-        env |>
-        get_state |>
+    env |>
+    get_state |>
+    x ->
+    Flux.unsqueeze(x, ndims(x) + 1) |>
     x ->
-        Flux.unsqueeze(x, ndims(x) + 1) |>
-        x ->
-        send_to_device(device(learner.approximator), x) |>
-        learner.approximator |>
-        vec |>
-        send_to_host
-
-    if ActionStyle(env) === FULL_ACTION_SET
-        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
-    end
-    probs
+    send_to_device(device(learner.approximator), x) |>
+    learner.approximator |>
+    vec |>
+    send_to_host
 end

 function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
@@ -133,9 +127,9 @@ function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)

     target_q = Qₜ(next_states)
     if haskey(t, :next_legal_actions_mask)
-        target_q .+=
-            typemin(eltype(target_q)) .*
-            (1 .- send_to_device(D, t[:next_legal_actions_mask]))
+        masked_value = fill(typemin(Float32), size(experience.next_legal_actions_mask))
+        masked_value[experience.next_legal_actions_mask] .= 0
+        target_q .+= send_to_device(D, masked_value)
     end

     q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
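The pipeline kept in the learner call relies on Julia's precedence rules: an anonymous function `x -> ...` extends as far to the right as possible, so each `x ->` wraps the remainder of the `|>` chain. A small sketch with stand-in functions (hypothetical names, only meant to show how the chain parses; not from the repository):

```julia
# Hypothetical stand-ins; only the shape of the pipeline matters here.
get_state(env) = env                      # pretend the env *is* its state
approximator(x) = Float32[0.1, 0.9, 0.4]  # pretend Q-network
unsqueeze(x) = reshape(x, size(x)..., 1)  # stands in for Flux.unsqueeze(x, ndims(x) + 1)

env = Float32[0.0, 1.0]

q = env |>
    get_state |>
    x ->
    unsqueeze(x) |>
    approximator |>
    vec
# Parses as env |> get_state |> (x -> (unsqueeze(x) |> approximator |> vec)),
# i.e. everything after `x ->` is the closure body, so q == Float32[0.1, 0.9, 0.4].
```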

src/algorithms/dqns/iqn.jl

Lines changed: 4 additions & 8 deletions
@@ -156,11 +156,7 @@ function (learner::IQNLearner)(env)
     τ = rand(learner.device_rng, Float32, learner.K, 1)
     τₑₘ = embed(τ, learner.Nₑₘ)
     quantiles = learner.approximator(state, τₑₘ)
-    probs = vec(mean(quantiles; dims = 2)) |> send_to_host
-    if ActionStyle(env) === FULL_ACTION_SET
-        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
-    end
-    probs
+    vec(mean(quantiles; dims = 2)) |> send_to_host
 end

 embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :))
@@ -187,9 +183,9 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
     avg_zₜ = mean(zₜ, dims = 2)

     if !isnothing(batch.next_legal_actions_mask)
-        avg_zₜ .+=
-            typemin(eltype(avg_zₜ)) .*
-            (1 .- send_to_device(D, batch.next_legal_actions_mask))
+        masked_value = fill(typemin(Float32), size(batch.next_legal_actions_mask))
+        masked_value[batch.next_legal_actions_mask] .= 0
+        avg_zₜ .+= send_to_device(D, masked_value)
     end

     aₜ = argmax(avg_zₜ, dims = 1)
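The same three-line masking pattern now appears in dqn.jl, iqn.jl, prioritized_dqn.jl and rainbow.jl. A hypothetical helper that would capture it (not part of this commit, just a refactoring sketch):

```julia
# Hypothetical helper, not in the repository: 0 for legal actions, -Inf32 for illegal ones.
function additive_mask(legal_actions_mask)
    m = fill(typemin(Float32), size(legal_actions_mask))
    m[legal_actions_mask] .= 0
    return m
end

# Usage corresponding to the IQN hunk above:
# avg_zₜ .+= send_to_device(D, additive_mask(batch.next_legal_actions_mask))
```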

src/algorithms/dqns/prioritized_dqn.jl

Lines changed: 11 additions & 17 deletions
@@ -102,21 +102,15 @@ end
 if `!isnothing(stack_size)`.
 """
 function (learner::PrioritizedDQNLearner)(env)
-    probs =
-        env |>
-        get_state |>
+    env |>
+    get_state |>
+    x ->
+    Flux.unsqueeze(x, ndims(x) + 1) |>
     x ->
-        Flux.unsqueeze(x, ndims(x) + 1) |>
-        x ->
-        send_to_device(device(learner.approximator), x) |>
-        learner.approximator |>
-        vec |>
-        send_to_host
-
-    if ActionStyle(env) === FULL_ACTION_SET
-        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
-    end
-    probs
+    send_to_device(device(learner.approximator), x) |>
+    learner.approximator |>
+    vec |>
+    send_to_host
 end

 function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
@@ -141,9 +135,9 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)

     target_q = Qₜ(next_states)
     if !isnothing(batch.next_legal_actions_mask)
-        target_q .+=
-            typemin(eltype(target_q)) .*
-            (1 .- send_to_device(D, batch.next_legal_actions_mask))
+        masked_value = fill(typemin(Float32), size(batch.next_legal_actions_mask))
+        masked_value[batch.next_legal_actions_mask] .= 0
+        target_q .+= send_to_device(D, masked_value)
     end

     q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
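With the additive mask in place, the column-wise maximum that forms the TD target can only be attained by a legal action. A toy check with illustrative values, not code from the repository:

```julia
# actions × batch; in column 1 only action 1 is legal, in column 2 only action 2.
target_q = Float32[1.0 5.0; 4.0 2.0]
next_legal_actions_mask = Bool[1 0; 0 1]

masked_value = fill(typemin(Float32), size(next_legal_actions_mask))
masked_value[next_legal_actions_mask] .= 0
target_q .+= masked_value

q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
# q′ == Float32[1.0, 2.0]: the larger but illegal entries 4.0 and 5.0 never win.
```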

src/algorithms/dqns/rainbow.jl

Lines changed: 4 additions & 8 deletions
@@ -126,11 +126,7 @@ function (learner::RainbowLearner)(env)
     state = Flux.unsqueeze(state, ndims(state) + 1)
     logits = learner.approximator(state)
     q = learner.support .* softmax(reshape(logits, :, learner.n_actions))
-    probs = vec(sum(q, dims = 1)) |> send_to_host
-    if ActionStyle(env) === FULL_ACTION_SET
-        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
-    end
-    probs
+    vec(sum(q, dims = 1)) |> send_to_host
 end

 function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
@@ -161,9 +157,9 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
     next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :)
     next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :)
     if !isnothing(batch.next_legal_actions_mask)
-        next_q .+=
-            typemin(eltype(next_q)) .*
-            (1 .- send_to_device(D, batch.next_legal_actions_mask))
+        masked_value = fill(typemin(Float32), size(batch.next_legal_actions_mask))
+        masked_value[batch.next_legal_actions_mask] .= 0
+        next_q .+= send_to_device(D, masked_value)
     end
     next_prob_select = select_best_probs(next_probs, next_q)
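In the Rainbow update the values being masked are the per-action Q-values obtained by collapsing the distributional output against the support, after which the best next action is selected. A toy single-sample version with illustrative numbers, not code from the repository:

```julia
support = Float32[-1.0, 0.0, 1.0]                   # n_atoms = 3
next_probs = Float32[0.1 0.8; 0.2 0.1; 0.7 0.1]     # n_atoms × n_actions
next_q = vec(sum(support .* next_probs, dims = 1))  # ≈ Float32[0.6, -0.7]

next_legal_actions_mask = Bool[0, 1]                # only action 2 is legal
masked_value = fill(typemin(Float32), size(next_legal_actions_mask))
masked_value[next_legal_actions_mask] .= 0
next_q .+= masked_value

argmax(next_q)  # == 2, even though action 1 had the larger raw Q-value
```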

src/algorithms/policy_gradient/A2C.jl

Lines changed: 5 additions & 21 deletions
@@ -28,28 +28,17 @@ Base.@kwdef mutable struct A2CLearner{A<:ActorCritic} <: AbstractLearner
 end

 function (learner::A2CLearner)(env::MultiThreadEnv)
-    logits =
-        learner.approximator.actor(send_to_device(
-            device(learner.approximator),
-            get_state(env),
-        )) |> send_to_host
-
-    if ActionStyle(env[1]) === FULL_ACTION_SET
-        logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env))
-    end
-    logits
+    learner.approximator.actor(send_to_device(
+        device(learner.approximator),
+        get_state(env),
+    )) |> send_to_host
 end

 function (learner::A2CLearner)(env)
     s = get_state(env)
     s = Flux.unsqueeze(s, ndims(s) + 1)
     s = send_to_device(device(learner.approximator), s)
-    logits = learner.approximator.actor(s) |> vec |> send_to_host
-
-    if ActionStyle(env) === FULL_ACTION_SET
-        logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env))
-    end
-    logits
+    learner.approximator.actor(s) |> vec |> send_to_host
 end

 function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory)
@@ -87,11 +76,6 @@ function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory)
     ps = Flux.params(AC)
     gs = gradient(ps) do
         logits = AC.actor(states_flattened)
-        if haskey(t, :legal_actions_mask)
-            logits .+=
-                typemin(eltype(logits)) .*
-                (1 .- flatten_batch(send_to_device(D, t[:legal_actions_mask])))
-        end
         probs = softmax(logits)
         log_probs = logsoftmax(logits)
         log_probs_select = log_probs[actions]
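For A2C (and PPO below) the commit simply drops the logits masking from both the action-selection path and the gradient block; it does not substitute the fill/assign pattern here. As an aside, if one did want to mask logits before softmax, overwriting illegal entries rather than adding a term avoids the same -Inf32 * 0 pitfall. A hypothetical sketch, not something this commit introduces:

```julia
logits = Float32[0.5, 1.5, -0.3]
legal_actions_mask = Bool[1, 0, 1]

# Overwrite illegal entries instead of adding -Inf32 .* (1 .- mask),
# which would turn the legal entries into NaN.
masked_logits = ifelse.(legal_actions_mask, logits, typemin(Float32))
probs = exp.(masked_logits) ./ sum(exp.(masked_logits))   # softmax by hand
# probs[2] == 0.0f0 and probs[1] + probs[3] ≈ 1.0f0
```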

src/algorithms/policy_gradient/ppo.jl

Lines changed: 5 additions & 19 deletions
@@ -75,28 +75,17 @@ function PPOLearner(;
 end

 function (learner::PPOLearner)(env::MultiThreadEnv)
-    logits =
-        learner.approximator.actor(send_to_device(
-            device(learner.approximator),
-            get_state(env),
-        )) |> send_to_host
-
-    if ActionStyle(env[1]) === FULL_ACTION_SET
-        logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env))
-    end
-    logits
+    learner.approximator.actor(send_to_device(
+        device(learner.approximator),
+        get_state(env),
+    )) |> send_to_host
 end

 function (learner::PPOLearner)(env)
     s = get_state(env)
     s = Flux.unsqueeze(s, ndims(s) + 1)
     s = send_to_device(device(learner.approximator), s)
-    logits = learner.approximator.actor(s) |> vec |> send_to_host
-
-    if ActionStyle(env) === FULL_ACTION_SET
-        logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env))
-    end
-    logits
+    learner.approximator.actor(s) |> vec |> send_to_host
 end

 function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
@@ -154,9 +143,6 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
     gs = gradient(ps) do
         v′ = AC.critic(s) |> vec
         logit′ = AC.actor(s)
-        if haskey(t, :legal_actions_mask)
-            logit′ .+= typemin(eltype(logit′)) .* (1 .- lam)
-        end
         p′ = softmax(logit′)
         log_p′ = logsoftmax(logit′)
         log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
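Unrelated to the mask itself, the line kept at the end of this hunk uses CartesianIndex broadcasting to gather, for each batch column, the log-probability of the action actually taken. A toy version with illustrative numbers, not code from the repository:

```julia
log_p′ = log.(Float32[0.2 0.5; 0.8 0.5])   # n_actions × batch
a = [2, 1]                                 # action 2 was taken in column 1, action 1 in column 2

log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
# ≈ Float32[log(0.8), log(0.5)]
```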

0 commit comments
