|
82 | 82 |
|
83 | 83 | DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
|
84 | 84 |
|
85 |
| -# Handle setting `nadapts` and `discard_initial` |
86 |
| -function AbstractMCMC.sample( |
87 |
| - rng::AbstractRNG, |
88 |
| - model::DynamicPPL.Model, |
89 |
| - sampler::Sampler{<:AdaptiveHamiltonian}, |
90 |
| - N::Integer; |
91 |
| - chain_type=DynamicPPL.default_chain_type(sampler), |
92 |
| - resume_from=nothing, |
93 |
| - initial_state=DynamicPPL.loadstate(resume_from), |
94 |
| - progress=PROGRESS[], |
95 |
| - nadapts=sampler.alg.n_adapts, |
96 |
| - discard_adapt=true, |
97 |
| - discard_initial=-1, |
98 |
| - kwargs..., |
99 |
| -) |
100 |
| - if resume_from === nothing |
101 |
| - # If `nadapts` is `-1`, then the user called a convenience |
102 |
| - # constructor like `NUTS()` or `NUTS(0.65)`, |
103 |
| - # and we should set a default for them. |
| 85 | +get_adtype(alg::Hamiltonian) = alg.adtype |
| 86 | + |
| 87 | +function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs) |
| 88 | + resume_from = get(kwargs, :resume_from, nothing) |
| 89 | + nadapts = get(kwargs, :nadapts, alg.n_adapts) |
| 90 | + discard_adapt = get(kwargs, :discard_adapt, true) |
| 91 | + discard_initial = get(kwargs, :discard_initial, -1) |
| 92 | + |
| 93 | + return if resume_from === nothing |
| 94 | + # If `nadapts` is `-1`, then the user called a convenience constructor |
| 95 | + # like `NUTS()` or `NUTS(0.65)`, and we should set a default for them. |
104 | 96 | if nadapts == -1
|
105 |
| - _nadapts = min(1000, N ÷ 2) |
| 97 | + _nadapts = min(1000, N ÷ 2) # Default to half of N, capped at 1000, if not specified |
106 | 98 | else
|
107 | 99 | _nadapts = nadapts
|
108 | 100 | end
|
109 |
| - |
110 | 101 | # If `discard_initial` is `-1`, then users did not specify the keyword argument.
|
111 | 102 | if discard_initial == -1
|
112 | 103 | _discard_initial = discard_adapt ? _nadapts : 0
|
113 | 104 | else
|
114 | 105 | _discard_initial = discard_initial
|
115 | 106 | end
|
116 | 107 |
|
117 |
| - return AbstractMCMC.mcmcsample( |
118 |
| - rng, |
119 |
| - model, |
120 |
| - sampler, |
121 |
| - N; |
122 |
| - chain_type=chain_type, |
123 |
| - progress=progress, |
124 |
| - nadapts=_nadapts, |
125 |
| - discard_initial=_discard_initial, |
126 |
| - kwargs..., |
127 |
| - ) |
| 108 | + (nadapts=_nadapts, discard_initial=_discard_initial, kwargs...) |
128 | 109 | else
|
129 |
| - return AbstractMCMC.mcmcsample( |
130 |
| - rng, |
131 |
| - model, |
132 |
| - sampler, |
133 |
| - N; |
134 |
| - chain_type=chain_type, |
135 |
| - initial_state=initial_state, |
136 |
| - progress=progress, |
137 |
| - nadapts=0, |
138 |
| - discard_adapt=false, |
139 |
| - discard_initial=0, |
140 |
| - kwargs..., |
141 |
| - ) |
| 110 | + (nadapts=0, discard_adapt=false, discard_initial=0, kwargs...) |
142 | 111 | end
|
143 | 112 | end
|
144 | 113 |
|
@@ -172,42 +141,32 @@ function find_initial_params(
|
172 | 141 | )
|
173 | 142 | end
|
174 | 143 |
|
175 |
| -function DynamicPPL.initialstep( |
| 144 | +function AbstractMCMC.step( |
176 | 145 | rng::AbstractRNG,
|
177 |
| - model::AbstractModel, |
178 |
| - spl::Sampler{<:Hamiltonian}, |
179 |
| - vi_original::AbstractVarInfo; |
| 146 | + ldf::LogDensityFunction, |
| 147 | + spl::Sampler{<:Hamiltonian}; |
180 | 148 | initial_params=nothing,
|
181 | 149 | nadapts=0,
|
182 | 150 | kwargs...,
|
183 | 151 | )
|
184 |
| - # Transform the samples to unconstrained space and compute the joint log probability. |
185 |
| - vi = DynamicPPL.link(vi_original, model) |
| 152 | + ldf.adtype === nothing && |
| 153 | + error("Hamiltonian sampler received a LogDensityFunction without an AD backend") |
186 | 154 |
|
187 |
| - # Extract parameters. |
188 |
| - theta = vi[:] |
| 155 | + theta = ldf.varinfo[:] |
| 156 | + |
| 157 | + has_initial_params = initial_params !== nothing |
189 | 158 |
|
190 | 159 | # Create a Hamiltonian.
|
191 | 160 | metricT = getmetricT(spl.alg)
|
192 | 161 | metric = metricT(length(theta))
|
193 |
| - ldf = DynamicPPL.LogDensityFunction( |
194 |
| - model, |
195 |
| - vi, |
196 |
| - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we |
197 |
| - # need to pass in the sampler? (In fact LogDensityFunction defaults to |
198 |
| - # using leafcontext(model.context) so could we just remove the argument |
199 |
| - # entirely?) |
200 |
| - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)); |
201 |
| - adtype=spl.alg.adtype, |
202 |
| - ) |
203 | 162 | lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
|
204 | 163 | lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
|
205 | 164 | hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
|
206 | 165 |
|
207 | 166 | # If no initial parameters are provided, resample until the log probability
|
208 | 167 | # and its gradient are finite. Otherwise, just use the existing parameters.
|
209 | 168 | vi, z = if initial_params === nothing
|
210 |
| - find_initial_params(rng, model, vi, hamiltonian) |
| 169 | + find_initial_params(rng, ldf.model, ldf.varinfo, hamiltonian) |
211 | 170 | else
|
212 | 171 | vi, AHMC.phasepoint(rng, theta, hamiltonian)
|
213 | 172 | end
|
@@ -248,23 +207,20 @@ function DynamicPPL.initialstep(
|
248 | 207 | vi = setlogp!!(vi, log_density_old)
|
249 | 208 | end
|
250 | 209 |
|
251 |
| - transition = Transition(model, vi, t) |
| 210 | + transition = Transition(ldf.model, vi, t) |
252 | 211 | state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
|
253 | 212 |
|
254 | 213 | return transition, state
|
255 | 214 | end
|
256 | 215 |
|
257 | 216 | function AbstractMCMC.step(
|
258 | 217 | rng::Random.AbstractRNG,
|
259 |
| - model::Model, |
| 218 | + ldf::LogDensityFunction, |
260 | 219 | spl::Sampler{<:Hamiltonian},
|
261 | 220 | state::HMCState;
|
262 | 221 | nadapts=0,
|
263 | 222 | kwargs...,
|
264 | 223 | )
|
265 |
| - # Get step size |
266 |
| - @debug "current ϵ" getstepsize(spl, state) |
267 |
| - |
268 | 224 | # Compute transition.
|
269 | 225 | hamiltonian = state.hamiltonian
|
270 | 226 | z = state.z
|
@@ -294,13 +250,15 @@ function AbstractMCMC.step(
|
294 | 250 | end
|
295 | 251 |
|
296 | 252 | # Compute next transition and state.
|
297 |
| - transition = Transition(model, vi, t) |
| 253 | + transition = Transition(ldf.model, vi, t) |
298 | 254 | newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
|
299 | 255 |
|
300 | 256 | return transition, newstate
|
301 | 257 | end
|
302 | 258 |
|
303 | 259 | function get_hamiltonian(model, spl, vi, state, n)
|
| 260 | + # TODO(penelopeysm): This is used by the Gibbs sampler, we can |
| 261 | + # simplify it to use LDF when Gibbs is reworked |
304 | 262 | metric = gen_metric(n, spl, state)
|
305 | 263 | ldf = DynamicPPL.LogDensityFunction(
|
306 | 264 | model,
|
|
0 commit comments