Skip to content

Commit e4cb590

Browse files
committed
Unwrap the other Hamiltonians to make them AbstractSamplers
1 parent 49f6988 commit e4cb590

File tree

5 files changed

+13
-44
lines changed

5 files changed

+13
-44
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,12 @@ struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
4343
stepsize::S
4444
end
4545

46-
function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS})
46+
function DynamicPPL.initialsampler(::DynamicNUTS)
4747
return DynamicPPL.SampleFromUniform()
4848
end
4949

5050
function AbstractMCMC.step(
51-
rng::Random.AbstractRNG,
52-
ldf::DynamicPPL.LogDensityFunction,
53-
spl::DynamicPPL.Sampler{<:DynamicNUTS};
54-
kwargs...,
51+
rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::DynamicNUTS; kwargs...
5552
)
5653
vi = ldf.varinfo
5754

@@ -76,13 +73,13 @@ end
7673
function AbstractMCMC.step(
7774
rng::Random.AbstractRNG,
7875
ldf::DynamicPPL.LogDensityFunction,
79-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
76+
spl::DynamicNUTS,
8077
state::DynamicNUTSState;
8178
kwargs...,
8279
)
8380
# Compute next sample.
8481
vi = state.vi
85-
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ldf, state.stepsize)
82+
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ldf, state.stepsize)
8683
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
8784

8885
# Update the variables.

src/mcmc/sghmc.jl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ struct SGHMCState{V<:AbstractVarInfo,T<:AbstractVector{<:Real}}
5151
end
5252

5353
function AbstractMCMC.step(
54-
rng::Random.AbstractRNG,
55-
ldf::DynamicPPL.LogDensityFunction,
56-
spl::Sampler{<:SGHMC};
57-
kwargs...,
54+
rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::SGHMC; kwargs...
5855
)
5956
vi = ldf.varinfo
6057

@@ -68,7 +65,7 @@ end
6865
function AbstractMCMC.step(
6966
rng::Random.AbstractRNG,
7067
ldf::DynamicPPL.LogDensityFunction,
71-
spl::Sampler{<:SGHMC},
68+
spl::SGHMC,
7269
state::SGHMCState;
7370
kwargs...,
7471
)
@@ -81,8 +78,8 @@ function AbstractMCMC.step(
8178
# equation (15) of Chen et al. (2014)
8279
v = state.velocity
8380
θ .+= v
84-
η = spl.alg.learning_rate
85-
α = spl.alg.momentum_decay
81+
η = spl.learning_rate
82+
α = spl.momentum_decay
8683
newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
8784

8885
# Save new variables and recompute log density.
@@ -201,31 +198,24 @@ struct SGLDState{V<:AbstractVarInfo}
201198
end
202199

203200
function AbstractMCMC.step(
204-
rng::Random.AbstractRNG,
205-
ldf::DynamicPPL.LogDensityFunction,
206-
spl::Sampler{<:SGLD};
207-
kwargs...,
201+
rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::SGLD; kwargs...
208202
)
209203
# Create first sample and state.
210204
vi = ldf.varinfo
211-
sample = SGLDTransition(ldf.model, vi, zero(spl.alg.stepsize(0)))
205+
sample = SGLDTransition(ldf.model, vi, zero(spl.stepsize(0)))
212206
state = SGLDState(vi, 1)
213207
return sample, state
214208
end
215209

216210
function AbstractMCMC.step(
217-
rng::Random.AbstractRNG,
218-
ldf::LogDensityFunction,
219-
spl::Sampler{<:SGLD},
220-
state::SGLDState;
221-
kwargs...,
211+
rng::Random.AbstractRNG, ldf::LogDensityFunction, spl::SGLD, state::SGLDState; kwargs...
222212
)
223213
# Perform gradient step.
224214
vi = state.vi
225215
θ = vi[:]
226216
grad = last(LogDensityProblems.logdensity_and_gradient(ldf, θ))
227217
step = state.step
228-
stepsize = spl.alg.stepsize(step)
218+
stepsize = spl.stepsize(step)
229219
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
230220

231221
# Save new variables and recompute log density.

test/ext/dynamichmc.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ using Turing
1414
@testset "TuringDynamicHMCExt" begin
1515
spl = externalsampler(DynamicHMC.NUTS())
1616

17-
@testset "alg_str" begin
18-
@test DynamicPPL.alg_str(Sampler(spl)) == "DynamicNUTS"
19-
end
20-
2117
@testset "sample() interface" begin
2218
@model function demo_normal(x)
2319
a ~ Normal()

test/mcmc/repeat_sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using Turing
1414
num_chains = 2
1515

1616
rng = StableRNG(0)
17-
for sampler in [MH(), DynamicPPL.Sampler(HMC(0.01, 4))]
17+
for sampler in [MH(), HMC(0.01, 4)]
1818
model_or_ldf = if sampler isa MH
1919
gdemo_default
2020
else

test/mcmc/sghmc.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,6 @@ end
7979
@testset "sghmc constructor" begin
8080
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1)
8181
@test alg isa SGHMC
82-
sampler = DynamicPPL.Sampler(alg)
83-
@test sampler isa DynamicPPL.Sampler{<:SGHMC}
84-
85-
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1)
86-
@test alg isa SGHMC
87-
sampler = DynamicPPL.Sampler(alg)
88-
@test sampler isa DynamicPPL.Sampler{<:SGHMC}
8982
end
9083

9184
@testset "sghmc inference" begin
@@ -100,13 +93,6 @@ end
10093
@testset "sgld constructor" begin
10194
alg = SGLD(; stepsize=PolynomialStepsize(0.25))
10295
@test alg isa SGLD
103-
sampler = DynamicPPL.Sampler(alg)
104-
@test sampler isa DynamicPPL.Sampler{<:SGLD}
105-
106-
alg = SGLD(; stepsize=PolynomialStepsize(0.25))
107-
@test alg isa SGLD
108-
sampler = DynamicPPL.Sampler(alg)
109-
@test sampler isa DynamicPPL.Sampler{<:SGLD}
11096
end
11197
@testset "sgld inference" begin
11298
rng = StableRNG(1)

0 commit comments

Comments
 (0)