Skip to content

Commit 439df83

Browse files
committed
Fix Hamiltonians in sghmc.jl / DynamicHMCExt
1 parent 05303ec commit 439df83

File tree

3 files changed

+81
-115
lines changed

3 files changed

+81
-115
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ State of the [`DynamicNUTS`](@ref) sampler.
3535
# Fields
3636
$(TYPEDFIELDS)
3737
"""
38-
struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
39-
logdensity::L
38+
struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
4039
vi::V
4140
"Cache of sample, log density, and gradient of log density evaluation."
4241
cache::C
@@ -48,30 +47,17 @@ function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS})
4847
return DynamicPPL.SampleFromUniform()
4948
end
5049

51-
function DynamicPPL.initialstep(
50+
function AbstractMCMC.step(
5251
rng::Random.AbstractRNG,
53-
model::DynamicPPL.Model,
54-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
55-
vi::DynamicPPL.AbstractVarInfo;
52+
ldf::DynamicPPL.LogDensityFunction,
53+
spl::DynamicPPL.Sampler{<:DynamicNUTS};
5654
kwargs...,
5755
)
58-
# Ensure that initial sample is in unconstrained space.
59-
if !DynamicPPL.islinked(vi)
60-
vi = DynamicPPL.link!!(vi, model)
61-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
62-
end
63-
64-
# Define log-density function.
65-
ℓ = DynamicPPL.LogDensityFunction(
66-
model,
67-
vi,
68-
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
69-
adtype=spl.alg.adtype,
70-
)
56+
vi = ldf.varinfo
7157

7258
# Perform initial step.
7359
results = DynamicHMC.mcmc_keep_warmup(
74-
rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
60+
rng, ldf, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
7561
)
7662
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
7763
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
@@ -81,32 +67,31 @@ function DynamicPPL.initialstep(
8167
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
8268

8369
# Create first sample and state.
84-
sample = Turing.Inference.Transition(model, vi)
85-
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
70+
sample = Turing.Inference.Transition(ldf.model, vi)
71+
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
8672

8773
return sample, state
8874
end
8975

9076
function AbstractMCMC.step(
9177
rng::Random.AbstractRNG,
92-
model::DynamicPPL.Model,
78+
ldf::DynamicPPL.LogDensityFunction,
9379
spl::DynamicPPL.Sampler{<:DynamicNUTS},
9480
state::DynamicNUTSState;
9581
kwargs...,
9682
)
9783
# Compute next sample.
9884
vi = state.vi
99-
ℓ = state.logdensity
100-
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
85+
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ldf, state.stepsize)
10186
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
10287

10388
# Update the variables.
10489
vi = DynamicPPL.unflatten(vi, Q.q)
10590
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
10691

10792
# Create next sample and state.
108-
sample = Turing.Inference.Transition(model, vi)
109-
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
93+
sample = Turing.Inference.Transition(ldf.model, vi)
94+
newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)
11095

11196
return sample, newstate
11297
end

src/mcmc/hmc.jl

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
1+
# InferenceAlgorithm interface
2+
13
abstract type Hamiltonian <: InferenceAlgorithm end
4+
5+
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
6+
requires_unconstrained_space(::Hamiltonian) = true
7+
# TODO(penelopeysm): This is really quite dangerous code because it implicitly
8+
# assumes that any concrete type that subtypes `Hamiltonian` has an adtype
9+
# field.
10+
get_adtype(alg::Hamiltonian) = alg.adtype
11+
212
abstract type StaticHamiltonian <: Hamiltonian end
313
abstract type AdaptiveHamiltonian <: Hamiltonian end
414

15+
function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)
16+
resume_from = get(kwargs, :resume_from, nothing)
17+
nadapts = get(kwargs, :nadapts, alg.n_adapts)
18+
discard_adapt = get(kwargs, :discard_adapt, true)
19+
discard_initial = get(kwargs, :discard_initial, -1)
20+
21+
return if resume_from === nothing
22+
# If `nadapts` is `-1`, then the user called a convenience constructor
23+
# like `NUTS()` or `NUTS(0.65)`, and we should set a default for them.
24+
if nadapts == -1
25+
_nadapts = min(1000, N ÷ 2) # Default to 1000 if not specified
26+
else
27+
_nadapts = nadapts
28+
end
29+
# If `discard_initial` is `-1`, then users did not specify the keyword argument.
30+
if discard_initial == -1
31+
_discard_initial = discard_adapt ? _nadapts : 0
32+
else
33+
_discard_initial = discard_initial
34+
end
35+
36+
(nadapts=_nadapts, discard_initial=_discard_initial, kwargs...)
37+
else
38+
(nadapts=0, discard_adapt=false, discard_initial=0, kwargs...)
39+
end
40+
end
41+
542
###
643
### Sampler states
744
###
@@ -80,37 +117,6 @@ function HMC(
80117
return HMC(ϵ, n_leapfrog, metricT; adtype=adtype)
81118
end
82119

83-
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
84-
85-
get_adtype(alg::Hamiltonian) = alg.adtype
86-
87-
function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)
88-
resume_from = get(kwargs, :resume_from, nothing)
89-
nadapts = get(kwargs, :nadapts, alg.n_adapts)
90-
discard_adapt = get(kwargs, :discard_adapt, true)
91-
discard_initial = get(kwargs, :discard_initial, -1)
92-
93-
return if resume_from === nothing
94-
# If `nadapts` is `-1`, then the user called a convenience constructor
95-
# like `NUTS()` or `NUTS(0.65)`, and we should set a default for them.
96-
if nadapts == -1
97-
_nadapts = min(1000, N ÷ 2) # Default to 1000 if not specified
98-
else
99-
_nadapts = nadapts
100-
end
101-
# If `discard_initial` is `-1`, then users did not specify the keyword argument.
102-
if discard_initial == -1
103-
_discard_initial = discard_adapt ? _nadapts : 0
104-
else
105-
_discard_initial = discard_initial
106-
end
107-
108-
(nadapts=_nadapts, discard_initial=_discard_initial, kwargs...)
109-
else
110-
(nadapts=0, discard_adapt=false, discard_initial=0, kwargs...)
111-
end
112-
end
113-
114120
function find_initial_params(
115121
rng::Random.AbstractRNG,
116122
model::DynamicPPL.Model,
@@ -168,7 +174,7 @@ function AbstractMCMC.step(
168174
vi, z = if initial_params === nothing
169175
find_initial_params(rng, ldf.model, ldf.varinfo, hamiltonian)
170176
else
171-
vi, AHMC.phasepoint(rng, theta, hamiltonian)
177+
ldf.varinfo, AHMC.phasepoint(rng, theta, hamiltonian)
172178
end
173179
theta = vi[:]
174180

@@ -425,9 +431,9 @@ function NUTS(; kwargs...)
425431
return NUTS(-1, 0.65; kwargs...)
426432
end
427433

428-
for alg in (:HMC, :HMCDA, :NUTS)
429-
@eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT
430-
end
434+
getmetricT(::HMC{<:Any,metricT}) where {metricT} = metricT
435+
getmetricT(::HMCDA{<:Any,metricT}) where {metricT} = metricT
436+
getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
431437

432438
#####
433439
##### HMC core functions

src/mcmc/sghmc.jl

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -45,50 +45,37 @@ function SGHMC(;
4545
return SGHMC(_learning_rate, _momentum_decay, adtype)
4646
end
4747

48-
struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}}
49-
logdensity::L
48+
struct SGHMCState{V<:AbstractVarInfo,T<:AbstractVector{<:Real}}
5049
vi::V
5150
velocity::T
5251
end
5352

54-
function DynamicPPL.initialstep(
53+
function AbstractMCMC.step(
5554
rng::Random.AbstractRNG,
56-
model::Model,
57-
spl::Sampler{<:SGHMC},
58-
vi::AbstractVarInfo;
55+
ldf::DynamicPPL.LogDensityFunction,
56+
spl::Sampler{<:SGHMC};
5957
kwargs...,
6058
)
61-
# Transform the samples to unconstrained space and compute the joint log probability.
62-
if !DynamicPPL.islinked(vi)
63-
vi = DynamicPPL.link!!(vi, model)
64-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
65-
end
59+
vi = ldf.varinfo
6660

6761
# Compute initial sample and state.
68-
sample = Transition(model, vi)
69-
ℓ = DynamicPPL.LogDensityFunction(
70-
model,
71-
vi,
72-
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
73-
adtype=spl.alg.adtype,
74-
)
75-
state = SGHMCState(ℓ, vi, zero(vi[:]))
62+
sample = Transition(ldf.model, vi)
63+
state = SGHMCState(vi, zero(vi[:]))
7664

7765
return sample, state
7866
end
7967

8068
function AbstractMCMC.step(
8169
rng::Random.AbstractRNG,
82-
model::Model,
70+
ldf::DynamicPPL.LogDensityFunction,
8371
spl::Sampler{<:SGHMC},
8472
state::SGHMCState;
8573
kwargs...,
8674
)
8775
# Compute gradient of log density.
88-
= state.logdensity
8976
vi = state.vi
9077
θ = vi[:]
91-
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
78+
grad = last(LogDensityProblems.logdensity_and_gradient(ldf, θ))
9279

9380
# Update latent variables and velocity according to
9481
# equation (15) of Chen et al. (2014)
@@ -100,11 +87,11 @@ function AbstractMCMC.step(
10087

10188
# Save new variables and recompute log density.
10289
vi = DynamicPPL.unflatten(vi, θ)
103-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
90+
vi = last(DynamicPPL.evaluate!!(ldf.model, vi, DynamicPPL.SamplingContext(rng, spl)))
10491

10592
# Compute next sample and state.
106-
sample = Transition(model, vi)
107-
newstate = SGHMCState(ℓ, vi, newv)
93+
sample = Transition(ldf.model, vi)
94+
newstate = SGHMCState(vi, newv)
10895

10996
return sample, newstate
11097
end
@@ -208,57 +195,45 @@ metadata(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize)
208195

209196
DynamicPPL.getlogp(t::SGLDTransition) = t.lp
210197

211-
struct SGLDState{L,V<:AbstractVarInfo}
212-
logdensity::L
198+
struct SGLDState{V<:AbstractVarInfo}
213199
vi::V
214200
step::Int
215201
end
216202

217-
function DynamicPPL.initialstep(
203+
function AbstractMCMC.step(
218204
rng::Random.AbstractRNG,
219-
model::Model,
220-
spl::Sampler{<:SGLD},
221-
vi::AbstractVarInfo;
205+
ldf::DynamicPPL.LogDensityFunction,
206+
spl::Sampler{<:SGLD};
222207
kwargs...,
223208
)
224-
# Transform the samples to unconstrained space and compute the joint log probability.
225-
if !DynamicPPL.islinked(vi)
226-
vi = DynamicPPL.link!!(vi, model)
227-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
228-
end
229-
230209
# Create first sample and state.
231-
sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
232-
ℓ = DynamicPPL.LogDensityFunction(
233-
model,
234-
vi,
235-
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
236-
adtype=spl.alg.adtype,
237-
)
238-
state = SGLDState(ℓ, vi, 1)
239-
210+
vi = ldf.varinfo
211+
sample = SGLDTransition(ldf.model, vi, zero(spl.alg.stepsize(0)))
212+
state = SGLDState(vi, 1)
240213
return sample, state
241214
end
242215

243216
function AbstractMCMC.step(
244-
rng::Random.AbstractRNG, model::Model, spl::Sampler{<:SGLD}, state::SGLDState; kwargs...
217+
rng::Random.AbstractRNG,
218+
ldf::LogDensityFunction,
219+
spl::Sampler{<:SGLD},
220+
state::SGLDState;
221+
kwargs...,
245222
)
246223
# Perform gradient step.
247-
ℓ = state.logdensity
248224
vi = state.vi
249225
θ = vi[:]
250-
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
226+
grad = last(LogDensityProblems.logdensity_and_gradient(ldf, θ))
251227
step = state.step
252228
stepsize = spl.alg.stepsize(step)
253229
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
254230

255231
# Save new variables and recompute log density.
256232
vi = DynamicPPL.unflatten(vi, θ)
257-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
233+
vi = last(DynamicPPL.evaluate!!(ldf.model, vi, DynamicPPL.SamplingContext(rng, spl)))
258234

259235
# Compute next sample and state.
260-
sample = SGLDTransition(model, vi, stepsize)
261-
newstate = SGLDState(ℓ, vi, state.step + 1)
262-
236+
sample = SGLDTransition(ldf.model, vi, stepsize)
237+
newstate = SGLDState(vi, state.step + 1)
263238
return sample, newstate
264239
end

0 commit comments

Comments
 (0)