Skip to content

Commit 9d4dbf3

Browse files
committed
Update AbstractMCMC interface for Hamiltonian samplers
1 parent 088658e commit 9d4dbf3

File tree

1 file changed

+28
-70
lines changed

1 file changed

+28
-70
lines changed

src/mcmc/hmc.jl

Lines changed: 28 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -82,63 +82,32 @@ end
8282

8383
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
8484

85-
# Handle setting `nadapts` and `discard_initial`
86-
function AbstractMCMC.sample(
87-
rng::AbstractRNG,
88-
model::DynamicPPL.Model,
89-
sampler::Sampler{<:AdaptiveHamiltonian},
90-
N::Integer;
91-
chain_type=DynamicPPL.default_chain_type(sampler),
92-
resume_from=nothing,
93-
initial_state=DynamicPPL.loadstate(resume_from),
94-
progress=PROGRESS[],
95-
nadapts=sampler.alg.n_adapts,
96-
discard_adapt=true,
97-
discard_initial=-1,
98-
kwargs...,
99-
)
100-
if resume_from === nothing
101-
# If `nadapts` is `-1`, then the user called a convenience
102-
# constructor like `NUTS()` or `NUTS(0.65)`,
103-
# and we should set a default for them.
85+
get_adtype(alg::Hamiltonian) = alg.adtype
86+
87+
function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)
88+
resume_from = get(kwargs, :resume_from, nothing)
89+
nadapts = get(kwargs, :nadapts, alg.n_adapts)
90+
discard_adapt = get(kwargs, :discard_adapt, true)
91+
discard_initial = get(kwargs, :discard_initial, -1)
92+
93+
return if resume_from === nothing
94+
# If `nadapts` is `-1`, then the user called a convenience constructor
95+
# like `NUTS()` or `NUTS(0.65)`, and we should set a default for them.
10496
if nadapts == -1
105-
_nadapts = min(1000, N ÷ 2)
97+
_nadapts = min(1000, N ÷ 2) # Default to 1000 if not specified
10698
else
10799
_nadapts = nadapts
108100
end
109-
110101
# If `discard_initial` is `-1`, then users did not specify the keyword argument.
111102
if discard_initial == -1
112103
_discard_initial = discard_adapt ? _nadapts : 0
113104
else
114105
_discard_initial = discard_initial
115106
end
116107

117-
return AbstractMCMC.mcmcsample(
118-
rng,
119-
model,
120-
sampler,
121-
N;
122-
chain_type=chain_type,
123-
progress=progress,
124-
nadapts=_nadapts,
125-
discard_initial=_discard_initial,
126-
kwargs...,
127-
)
108+
(nadapts=_nadapts, discard_initial=_discard_initial, kwargs...)
128109
else
129-
return AbstractMCMC.mcmcsample(
130-
rng,
131-
model,
132-
sampler,
133-
N;
134-
chain_type=chain_type,
135-
initial_state=initial_state,
136-
progress=progress,
137-
nadapts=0,
138-
discard_adapt=false,
139-
discard_initial=0,
140-
kwargs...,
141-
)
110+
(nadapts=0, discard_adapt=false, discard_initial=0, kwargs...)
142111
end
143112
end
144113

@@ -172,42 +141,32 @@ function find_initial_params(
172141
)
173142
end
174143

175-
function DynamicPPL.initialstep(
144+
function AbstractMCMC.step(
176145
rng::AbstractRNG,
177-
model::AbstractModel,
178-
spl::Sampler{<:Hamiltonian},
179-
vi_original::AbstractVarInfo;
146+
ldf::LogDensityFunction,
147+
spl::Sampler{<:Hamiltonian};
180148
initial_params=nothing,
181149
nadapts=0,
182150
kwargs...,
183151
)
184-
# Transform the samples to unconstrained space and compute the joint log probability.
185-
vi = DynamicPPL.link(vi_original, model)
152+
ldf.adtype === nothing &&
153+
error("Hamiltonian sampler received a LogDensityFunction without an AD backend")
186154

187-
# Extract parameters.
188-
theta = vi[:]
155+
theta = ldf.varinfo[:]
156+
157+
has_initial_params = initial_params !== nothing
189158

190159
# Create a Hamiltonian.
191160
metricT = getmetricT(spl.alg)
192161
metric = metricT(length(theta))
193-
ldf = DynamicPPL.LogDensityFunction(
194-
model,
195-
vi,
196-
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
197-
# need to pass in the sampler? (In fact LogDensityFunction defaults to
198-
# using leafcontext(model.context) so could we just remove the argument
199-
# entirely?)
200-
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context));
201-
adtype=spl.alg.adtype,
202-
)
203162
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
204163
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
205164
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
206165

207166
# If no initial parameters are provided, resample until the log probability
208167
# and its gradient are finite. Otherwise, just use the existing parameters.
209168
vi, z = if initial_params === nothing
210-
find_initial_params(rng, model, vi, hamiltonian)
169+
find_initial_params(rng, ldf.model, ldf.varinfo, hamiltonian)
211170
else
212171
vi, AHMC.phasepoint(rng, theta, hamiltonian)
213172
end
@@ -248,23 +207,20 @@ function DynamicPPL.initialstep(
248207
vi = setlogp!!(vi, log_density_old)
249208
end
250209

251-
transition = Transition(model, vi, t)
210+
transition = Transition(ldf.model, vi, t)
252211
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
253212

254213
return transition, state
255214
end
256215

257216
function AbstractMCMC.step(
258217
rng::Random.AbstractRNG,
259-
model::Model,
218+
ldf::LogDensityFunction,
260219
spl::Sampler{<:Hamiltonian},
261220
state::HMCState;
262221
nadapts=0,
263222
kwargs...,
264223
)
265-
# Get step size
266-
@debug "current ϵ" getstepsize(spl, state)
267-
268224
# Compute transition.
269225
hamiltonian = state.hamiltonian
270226
z = state.z
@@ -294,13 +250,15 @@ function AbstractMCMC.step(
294250
end
295251

296252
# Compute next transition and state.
297-
transition = Transition(model, vi, t)
253+
transition = Transition(ldf.model, vi, t)
298254
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
299255

300256
return transition, newstate
301257
end
302258

303259
function get_hamiltonian(model, spl, vi, state, n)
260+
# TODO(penelopeysm): This is used by the Gibbs sampler, we can
261+
# simplify it to use LDF when Gibbs is reworked
304262
metric = gen_metric(n, spl, state)
305263
ldf = DynamicPPL.LogDensityFunction(
306264
model,

0 commit comments

Comments
 (0)