
Commit 2ddf949

Support rlintro (#144)
* rename TabularLearner to TabularRandomPolicy
* sync chapter01
* sync changes related to RLIntro
* add double q learner
* add Value based TDLearner
* add tabular dyna agent
* add LinearApproximator
* add TDλReturnLearner
* sync
* minor bugfix
* fix tests
* minor bugfix due to RLCore
* bump version
Parent: 335662a

27 files changed (+1314 −213 lines)

Project.toml

Lines changed: 3 additions & 2 deletions
@@ -1,13 +1,14 @@
 name = "ReinforcementLearningZoo"
 uuid = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
 authors = ["Jun Tian <tianjun.cpp@gmail.com>"]
-version = "0.3.1"
+version = "0.3.2"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

@@ -38,7 +39,7 @@ Flux = "0.11"
 IntervalSets = "0.5"
 MacroTools = "0.5"
 ReinforcementLearningBase = "0.9"
-ReinforcementLearningCore = "0.6.3"
+ReinforcementLearningCore = "0.7"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StableRNGs = "1.0"

src/algorithms/algorithms.jl

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+include("tabular/tabular.jl")
 include("dqns/dqns.jl")
 include("policy_gradient/policy_gradient.jl")
 include("searching/searching.jl")

src/algorithms/cfr/external_sampling_mccfr.jl

Lines changed: 1 addition & 2 deletions
@@ -29,7 +29,6 @@ function ExternalSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_
         TabularRandomPolicy(;
             rng = rng,
             table = Dict{state_type,Vector{Float64}}(),
-            is_normalized = true,
         ),
         rng,
     )

@@ -44,7 +43,7 @@ function RLBase.update!(p::ExternalSamplingMCCFRPolicy)
             strategy[m] .= v.cumulative_strategy ./ s
             update!(p.behavior_policy, k => strategy)
         else
-            # The TabularLearner will return uniform distribution by default.
+            # The TabularRandomPolicy will return uniform distribution by default.
             # So we do nothing here.
         end
     end

src/algorithms/cfr/outcome_sampling_mccfr.jl

Lines changed: 1 addition & 2 deletions
@@ -30,7 +30,6 @@ function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_R
         TabularRandomPolicy(;
             rng = rng,
             table = Dict{state_type,Vector{Float64}}(),
-            is_normalized = true,
         ),
         ϵ,
         rng,

@@ -55,7 +54,7 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
             strategy[m] .= v.cumulative_strategy ./ s
             update!(p.behavior_policy, k => strategy)
         else
-            # The TabularLearner will return uniform distribution by default.
+            # The TabularRandomPolicy will return uniform distribution by default.
             # So we do nothing here.
         end
     end

src/algorithms/cfr/tabular_cfr.jl

Lines changed: 2 additions & 3 deletions
@@ -23,7 +23,7 @@ end
 
 mutable struct TabularCFRPolicy{S,T,R<:AbstractRNG} <: AbstractCFRPolicy
     nodes::Dict{S,InfoStateNode}
-    behavior_policy::QBasedPolicy{TabularLearner{S,T},WeightedExplorer{true,R}}
+    behavior_policy::TabularRandomPolicy{S,T,R}
     is_reset_neg_regrets::Bool
     is_linear_averaging::Bool
     weighted_averaging_delay::Int

@@ -70,7 +70,6 @@ function TabularCFRPolicy(;
         TabularRandomPolicy(;
             rng = rng,
             table = Dict{state_type,Vector{Float64}}(),
-            is_normalized = true,
         ),
         is_reset_neg_regrets,
         is_linear_averaging,

@@ -91,7 +90,7 @@ function RLBase.update!(p::TabularCFRPolicy)
             strategy[m] .= v.cumulative_strategy ./ s
             update!(p.behavior_policy, k => strategy)
         else
-            # The TabularLearner will return uniform distribution by default.
+            # The TabularRandomPolicy will return uniform distribution by default.
             # So we do nothing here.
         end
     end
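
Note: in all three CFR policies above, the behavior policy is now a plain TabularRandomPolicy rather than a QBasedPolicy wrapping a TabularLearner, and the is_normalized keyword is gone. Below is a minimal sketch of how such a policy can be built and filled in, assuming only the constructor keywords and the key => probabilities form of update! that appear in the hunks above; TabularRandomPolicy is provided by ReinforcementLearningCore, and any name not shown in those hunks is an assumption for illustration.

    using Random
    using ReinforcementLearningCore  # assumed to export TabularRandomPolicy and update!

    # Behavior policy keyed by string-encoded information states, as in the CFR code.
    policy = TabularRandomPolicy(;
        rng = Random.GLOBAL_RNG,
        table = Dict{String,Vector{Float64}}(),
    )

    # Store an explicit strategy for one (hypothetical) information state.
    # States that were never updated fall back to a uniform distribution,
    # which is why the else branches above do nothing.
    update!(policy, "some_info_state" => [0.25, 0.75])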

src/algorithms/policy_gradient/MAC.jl

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ function (learner::MACLearner)(env)
     learner.approximator.actor(s) |> vec |> send_to_host
 end
 
-function RLBase.update!(learner::MACLearner, t::CircularArraySARTTrajectory)
+function RLBase.update!(learner::MACLearner, t::CircularArraySARTTrajectory, ::AbstractEnv, ::PreActStage)
     length(t) == 0 && return # in the first update, only state & action is inserted into trajectory
     learner.update_step += 1
     if learner.update_step % learner.update_freq == 0
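
Note: the same signature change is applied to the DDPG, PPO, SAC, and TD3 policies below: RLBase.update! now also receives the environment and a PreActStage marker, matching the stage-aware update hooks that come with ReinforcementLearningCore 0.7. A rough sketch of the dispatch pattern follows, using a hypothetical MyLearner; the surrounding types are assumed to be re-exported by ReinforcementLearningCore, and this is an illustration rather than the package's actual code.

    using ReinforcementLearningCore   # assumed to provide the trajectory, env, and stage types
    import ReinforcementLearningBase  # update! is extended on RLBase, as in the hunks above
    const RLBase = ReinforcementLearningBase

    struct MyLearner end  # hypothetical learner, for illustration only

    # This method only matches right before the agent acts; calls made at other
    # stages (e.g. PostActStage, PostEpisodeStage) simply do not hit it.
    function RLBase.update!(
        learner::MyLearner,
        t::CircularArraySARTTrajectory,
        ::AbstractEnv,
        ::PreActStage,
    )
        length(t) == 0 && return  # first call: only a state and action are stored
        # ... sample from t and take a learning step here ...
    end

Because the stage is part of the method signature, each learner can pick the point in the agent/environment loop where it reacts, instead of checking flags at runtime.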

src/algorithms/policy_gradient/ddpg.jl

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ function (p::DDPGPolicy)(env)
     end
 end
 
-function RLBase.update!(p::DDPGPolicy, traj::CircularArraySARTTrajectory)
+function RLBase.update!(p::DDPGPolicy, traj::CircularArraySARTTrajectory, ::AbstractEnv, ::PreActStage)
     length(traj) > p.update_after || return
     p.step % p.update_every == 0 || return
     inds, batch = sample(p.rng, traj, BatchSampler{SARTS}(p.batch_size))

src/algorithms/policy_gradient/ppo.jl

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ function (agent::Agent{<:RandomStartPolicy{<:PPOPolicy}})(env::AbstractEnv)
     end
 end
 
-function RLBase.update!(p::PPOPolicy, t::Union{PPOTrajectory, MaskedPPOTrajectory})
+function RLBase.update!(p::PPOPolicy, t::Union{PPOTrajectory, MaskedPPOTrajectory}, ::AbstractEnv, ::PreActStage)
     length(t) == 0 && return # in the first update, only state & action is inserted into trajectory
     p.update_step += 1
     if p.update_step % p.update_freq == 0

src/algorithms/policy_gradient/sac.jl

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ function evaluate(p::SACPolicy, state)
     return tanh.(z), logp_π
 end
 
-function RLBase.update!(p::SACPolicy, traj::CircularArraySARTTrajectory)
+function RLBase.update!(p::SACPolicy, traj::CircularArraySARTTrajectory, ::AbstractEnv, ::PreActStage)
     length(traj) > p.update_after || return
     p.step % p.update_every == 0 || return
     inds, batch = sample(p.rng, traj, BatchSampler{SARTS}(p.batch_size))

src/algorithms/policy_gradient/td3.jl

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ function (p::TD3Policy)(env)
     end
 end
 
-function RLBase.update!(p::TD3Policy, traj::CircularArraySARTTrajectory)
+function RLBase.update!(p::TD3Policy, traj::CircularArraySARTTrajectory, ::AbstractEnv, ::PreActStage)
     length(traj) > p.update_after || return
     p.step % p.update_every == 0 || return
     inds, batch = sample(p.rng, traj, BatchSampler{SARTS}(p.batch_size))
