This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit e3f9375

support legal_actions_mask (#69)
1 parent 1f0cc22 commit e3f9375

File tree: 7 files changed, +166 / -78 lines

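Every learner touched by this commit applies the same trick when the environment exposes a full action set: push the estimated value of each illegal action down to the most negative representable value, so that a greedy argmax can never select it. Below is a minimal standalone sketch of that idea with toy values; it is written with `ifelse.` rather than the diff's `typemin(...) .* (1 .- mask)` form, and none of the names come from the commit itself.

# Toy sketch of the legal-action masking idea used throughout this commit.
# `mask[i] == true` means action i is legal in the current state.
q    = Float32[0.3, 1.2, -0.7, 0.5]      # hypothetical action values
mask = Bool[true, false, true, true]     # hypothetical legal_actions_mask (action 2 illegal)

# Force illegal entries to the smallest Float32 so argmax can never pick them;
# `ifelse.` leaves the legal entries untouched.
masked_q = ifelse.(mask, q, typemin(eltype(q)))

argmax(masked_q)  # == 4 here; action 2 is never selected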

src/algorithms/dqns/common.jl

Lines changed: 20 additions & 5 deletions

@@ -13,7 +13,7 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
     # 1. sample indices based on priority
     valid_ind_range =
         isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
-    if t isa CircularCompactPSARTSATrajectory
+    if haskey(t, :priority)
         inds = Vector{Int}(undef, n)
         priorities = Vector{Float32}(undef, n)
         for i in 1:n
@@ -29,10 +29,21 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
         priorities = nothing
     end

+    next_inds = inds .+ h
+
     # 2. extract SARTS
     states = consecutive_view(t[:state], inds; n_stack = s)
     actions = consecutive_view(t[:action], inds)
-    next_states = consecutive_view(t[:state], inds .+ h; n_stack = s)
+    next_states = consecutive_view(t[:state], next_inds; n_stack = s)
+
+    if haskey(t, :legal_actions_mask)
+        legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds)
+        next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], inds)
+    else
+        legal_actions_mask = nothing
+        next_legal_actions_mask = nothing
+    end
+
     consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
     consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
     rewards, terminals = zeros(Float32, n), fill(false, n)
@@ -48,10 +59,12 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
     inds,
     (
         states = states,
+        legal_actions_mask = legal_actions_mask,
         actions = actions,
         rewards = rewards,
         terminals = terminals,
         next_states = next_states,
+        next_legal_actions_mask = next_legal_actions_mask,
         priorities = priorities,
     )
 end
@@ -70,23 +83,25 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)

     inds, experience = extract_experience(t, p.learner)

-    if t isa CircularCompactPSARTSATrajectory
+    if haskey(t, :priority)
         priorities = update!(p.learner, experience)
         t[:priority][inds] .= priorities
     else
         update!(p.learner, experience)
     end
 end

-function (agent::Agent{<:QBasedPolicy{<:PERLearners},<:CircularCompactPSARTSATrajectory})(
+function (agent::Agent{<:QBasedPolicy{<:PERLearners}})(
     ::RLCore.Training{PostActStage},
     env,
 )
     push!(
         agent.trajectory;
         reward = get_reward(env),
         terminal = get_terminal(env),
-        priority = agent.policy.learner.default_priority,
     )
+    if haskey(agent.trajectory, :priority)
+        push!(agent.trajectory; priority = agent.policy.learner.default_priority)
+    end
     nothing
 end
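The hunks above replace dispatch on the concrete trajectory type (`CircularCompactPSARTSATrajectory`) with a check for the `:priority` trace, so any trajectory that stores priorities takes the PER code path. A toy sketch of why the key check is more general, using a plain NamedTuple as a stand-in for a trajectory (not the actual RLCore trajectory type):

# Toy stand-ins for trajectories with and without a :priority trace.
traj_per   = (state = rand(Float32, 4, 10), priority = rand(Float32, 10))
traj_plain = (state = rand(Float32, 4, 10),)

# One predicate covers both, with no reference to a concrete trajectory type.
update_mode(t) = haskey(t, :priority) ? :prioritized : :uniform

update_mode(traj_per)    # :prioritized
update_mode(traj_plain)  # :uniform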

src/algorithms/dqns/dqn.jl

Lines changed: 35 additions & 12 deletions

@@ -87,16 +87,20 @@ end
 The state of the observation is assumed to have been stacked,
 if `!isnothing(stack_size)`.
 """
-(learner::DQNLearner)(env) =
-    env |>
+function (learner::DQNLearner)(env)
+    probs = env |>
     get_state |>
-    x ->
-        Flux.unsqueeze(x, ndims(x) + 1) |>
-    x ->
-        send_to_device(device(learner.approximator), x) |>
-    learner.approximator |>
-    send_to_host |>
-    Flux.squeezebatch
+    x -> Flux.unsqueeze(x, ndims(x) + 1) |>
+    x -> send_to_device(device(learner.approximator), x) |>
+    learner.approximator |>
+    vec |>
+    send_to_host
+
+    if ActionStyle(env) === FULL_ACTION_SET
+        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
+    end
+    probs
+end

 function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
     length(t[:terminal]) < learner.min_replay_history && return
@@ -124,10 +128,16 @@ function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
     terminals = send_to_device(D, experience.terminals)
     next_states = send_to_device(D, experience.next_states)

+    target_q = Qₜ(next_states)
+    if haskey(t, :next_legal_actions_mask)
+        target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, t[:next_legal_actions_mask]))
+    end
+
+    q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
+    G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
+
     gs = gradient(params(Q)) do
         q = Q(states)[actions]
-        q′ = dropdims(maximum(Qₜ(next_states); dims = 1), dims = 1)
-        G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
         loss = loss_func(G, q)
         ignore() do
             learner.loss = loss
@@ -147,9 +157,20 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner)
     valid_ind_range =
         isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
     inds = rand(learner.rng, valid_ind_range, n)
+    next_inds = inds .+ h
+
     states = consecutive_view(t[:state], inds; n_stack = s)
     actions = consecutive_view(t[:action], inds)
-    next_states = consecutive_view(t[:state], inds .+ h; n_stack = s)
+    next_states = consecutive_view(t[:state], next_inds; n_stack = s)
+
+    if haskey(t, :legal_actions_mask)
+        legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds)
+        next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], next_inds)
+    else
+        legal_actions_mask = nothing
+        next_legal_actions_mask = nothing
+    end
+
     consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
     consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
     rewards, terminals = zeros(Float32, n), fill(false, n)
@@ -167,9 +188,11 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner)
     end
     (
         states = states,
+        legal_actions_mask = legal_actions_mask,
         actions = actions,
         rewards = rewards,
         terminals = terminals,
         next_states = next_states,
+        next_legal_actions_mask = next_legal_actions_mask,
     )
 end
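The DQN hunks also hoist the target computation (`target_q`, the optional mask, `q′`, and `G`) out of the `gradient` block. A rough sketch of that target computation, using illustrative names rather than the learner's actual fields, and again using `ifelse.` for the mask:

# Sketch of an n-step TD target with optional legal-action masking.
# Shapes: target_q is n_actions × batch_size; rewards/terminals have length batch_size.
function td_target(target_q, rewards, terminals, γ, update_horizon; next_mask = nothing)
    if next_mask !== nothing
        # illegal next actions can never win the maximum below
        target_q = ifelse.(next_mask, target_q, typemin(eltype(target_q)))
    end
    q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
    rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
end

target_q  = rand(Float32, 3, 5)          # placeholder target-network output
rewards   = rand(Float32, 5)
terminals = Float32[0, 0, 1, 0, 0]
G = td_target(target_q, rewards, terminals, 0.99f0, 1)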

src/algorithms/dqns/iqn.jl

Lines changed: 13 additions & 4 deletions

@@ -156,7 +156,11 @@ function (learner::IQNLearner)(env)
     τ = rand(learner.device_rng, Float32, learner.K, 1)
     τₑₘ = embed(τ, learner.Nₑₘ)
     quantiles = learner.approximator(state, τₑₘ)
-    vec(mean(quantiles; dims = 2)) |> send_to_host
+    probs = vec(mean(quantiles; dims = 2)) |> send_to_host
+    if ActionStyle(env) === FULL_ACTION_SET
+        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
+    end
+    probs
 end

 embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :))
@@ -180,7 +184,13 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
     τ′ = rand(learner.device_rng, Float32, N′, batch_size) # TODO: support β distribution
     τₑₘ′ = embed(τ′, Nₑₘ)
     zₜ = Zₜ(s′, τₑₘ′)
-    aₜ = argmax(mean(zₜ, dims = 2), dims = 1)
+    avg_zₜ = mean(zₜ, dims = 2)
+
+    if !isnothing(batch.next_legal_actions_mask)
+        avg_zₜ .+= typemin(eltype(avg_zₜ)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
+    end
+
+    aₜ = argmax(avg_zₜ, dims = 1)
     aₜ = aₜ .+ typeof(aₜ)(CartesianIndices((0, 0:N′-1, 0)))
     qₜ = reshape(zₜ[aₜ], :, batch_size)
     target = reshape(r, 1, batch_size) .+ learner.γ * reshape(1 .- t, 1, batch_size) .* qₜ # reshape to allow broadcast
@@ -214,8 +224,7 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
         huber_loss ./ κ
     loss_per_quantile = reshape(sum(raw_loss; dims = 1), N, batch_size)
     loss_per_element = mean(loss_per_quantile; dims = 1) # use as priorities
-    loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size :
-        mean(loss_per_element)
+    loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size : mean(loss_per_element)
     ignore() do
         # @assert all(loss_per_element .>= 0)
         is_use_PER && (
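For IQN, the greedy next action is taken on the quantile values averaged over the sample dimension, after masking illegal next actions. A simplified sketch with made-up shapes (the real code keeps the singleton dimension and offsets the indices with `CartesianIndices`, which is omitted here):

using Statistics: mean

# zₜ: hypothetical target quantile values, n_actions × n_quantiles × batch_size
zₜ = rand(Float32, 4, 8, 2)
next_mask = trues(4, 2)
next_mask[2, 1] = false            # action 2 is illegal in the first sample

avg_zₜ = dropdims(mean(zₜ; dims = 2); dims = 2)               # n_actions × batch_size
avg_zₜ = ifelse.(next_mask, avg_zₜ, typemin(eltype(avg_zₜ)))  # mask illegal actions

aₜ = argmax(avg_zₜ; dims = 1)      # one CartesianIndex per column, never an illegal action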

src/algorithms/dqns/prioritized_dqn.jl

Lines changed: 21 additions & 12 deletions

@@ -101,16 +101,20 @@ end
 The state of the observation is assumed to have been stacked,
 if `!isnothing(stack_size)`.
 """
-(learner::PrioritizedDQNLearner)(env) =
-    env |>
+function (learner::PrioritizedDQNLearner)(env)
+    probs = env |>
     get_state |>
-    x ->
-        Flux.unsqueeze(x, ndims(x) + 1) |>
-    x ->
-        send_to_device(device(learner.approximator), x) |>
-    learner.approximator |>
-    send_to_host |>
-    Flux.squeezebatch
+    x -> Flux.unsqueeze(x, ndims(x) + 1) |>
+    x -> send_to_device(device(learner.approximator), x) |>
+    learner.approximator |>
+    vec |>
+    send_to_host
+
+    if ActionStyle(env) === FULL_ACTION_SET
+        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
+    end
+    probs
+end

 function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
     Q, Qₜ, γ, β, loss_func, update_horizon, batch_size = learner.approximator,
@@ -132,11 +136,16 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
     weights ./= maximum(weights)
     weights = send_to_device(D, weights)

+    target_q = Qₜ(next_states)
+    if !isnothing(batch.next_legal_actions_mask)
+        target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
+    end
+
+    q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
+    G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
+
     gs = gradient(params(Q)) do
         q = Q(states)[actions]
-        q′ = dropdims(maximum(Qₜ(next_states); dims = 1), dims = 1)
-        G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
-
         batch_losses = loss_func(G, q)
         loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
         ignore() do
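The importance-sampling weights visible in the context lines above (and written out in full in the Rainbow file below) are the inverse priorities raised to β, normalised by their maximum. A toy sketch with made-up priorities and β:

# Toy sketch of PER importance-sampling weights (β and the priorities are placeholders).
β = 0.5f0
priorities = Float32[0.5, 2.0, 0.1, 1.0]

weights = 1f0 ./ ((priorities .+ 1f-10) .^ β)
weights ./= maximum(weights)   # high-priority (often sampled) transitions are down-weighted;
                               # the lowest-priority one ends up with weight 1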

src/algorithms/dqns/rainbow.jl

Lines changed: 36 additions & 35 deletions

@@ -126,47 +126,43 @@ function (learner::RainbowLearner)(env)
     state = Flux.unsqueeze(state, ndims(state) + 1)
     logits = learner.approximator(state)
     q = learner.support .* softmax(reshape(logits, :, learner.n_actions))
-    # probs = vec(sum(q, dims=1)) .+ legal_action
-    vec(sum(q, dims = 1)) |> send_to_host
+    probs = vec(sum(q, dims = 1)) |> send_to_host
+    if ActionStyle(env) === FULL_ACTION_SET
+        probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
+    end
+    probs
 end

 function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
-    Q,
-    Qₜ,
-    γ,
-    β,
-    loss_func,
-    n_atoms,
-    n_actions,
-    support,
-    delta_z,
-    update_horizon,
-    batch_size = learner.approximator,
-    learner.target_approximator,
-    learner.γ,
-    learner.β_priority,
-    learner.loss_func,
-    learner.n_atoms,
-    learner.n_actions,
-    learner.support,
-    learner.delta_z,
-    learner.update_horizon,
-    learner.batch_size
-
+    Q = learner.approximator
+    Qₜ = learner.target_approximator
+    γ = learner.γ
+    β = learner.β_priority
+    loss_func = learner.loss_func
+    n_atoms = learner.n_atoms
+    n_actions = learner.n_actions
+    support = learner.support
+    delta_z = learner.delta_z
+    update_horizon = learner.update_horizon
+    batch_size = learner.batch_size
     D = device(Q)
-    states, rewards, terminals, next_states = map(
-        x -> send_to_device(D, x),
-        (batch.states, batch.rewards, batch.terminals, batch.next_states),
-    )
+    states = send_to_device(D, batch.states)
+    rewards = send_to_device(D, batch.rewards)
+    terminals = send_to_device(D, batch.terminals)
+    next_states = send_to_device(D, batch.next_states)
+
     actions = CartesianIndex.(batch.actions, 1:batch_size)
+
     target_support =
         reshape(rewards, 1, :) .+
         (reshape(support, :, 1) * reshape((γ^update_horizon) .* (1 .- terminals), 1, :))

     next_logits = Qₜ(next_states)
     next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :)
     next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :)
-    # next_q_argmax = argmax(cpu(next_q .+ next_legal_actions), dims=1)
+    if !isnothing(batch.next_legal_actions_mask)
+        next_q .+= typemin(eltype(next_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
+    end
     next_prob_select = select_best_probs(next_probs, next_q)

     target_distribution = project_distribution(
@@ -178,18 +174,23 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
         learner.Vₘₐₓ,
     )

-    updated_priorities = Vector{Float32}(undef, batch_size)
-    weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
-    weights ./= maximum(weights)
-    weights = send_to_device(D, weights)
+    is_use_PER = !isnothing(batch.priorities)  # whether Prioritized Experience Replay is used
+    if is_use_PER
+        updated_priorities = Vector{Float32}(undef, batch_size)
+        weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
+        weights ./= maximum(weights)
+        weights = send_to_device(D, weights)
+    end

     gs = gradient(Flux.params(Q)) do
         logits = reshape(Q(states), n_atoms, n_actions, :)
         select_logits = logits[:, actions]
         batch_losses = loss_func(select_logits, target_distribution)
-        loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
+        loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size : mean(batch_losses)
         ignore() do
-            updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
+            if is_use_PER
+                updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
+            end
             learner.loss = loss
         end
         loss
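With priorities now optional, the loss inside the gradient block chooses between the importance-weighted sum and a plain mean. A standalone sketch of that branch with placeholder values (the names below are stand-ins for what the learner computes):

using LinearAlgebra: dot
using Statistics: mean

batch_size   = 4
batch_losses = rand(Float32, batch_size)   # placeholder per-sample losses
weights      = rand(Float32, batch_size)   # placeholder importance-sampling weights
is_use_PER   = true

loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size :
                    mean(batch_losses)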
