@@ -45,50 +45,37 @@ function SGHMC(;
     return SGHMC(_learning_rate, _momentum_decay, adtype)
 end

-struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}}
-    logdensity::L
+struct SGHMCState{V<:AbstractVarInfo,T<:AbstractVector{<:Real}}
     vi::V
     velocity::T
 end

-function DynamicPPL.initialstep(
+function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model::Model,
-    spl::Sampler{<:SGHMC},
-    vi::AbstractVarInfo;
+    ldf::DynamicPPL.LogDensityFunction,
+    spl::Sampler{<:SGHMC};
     kwargs...,
 )
-    # Transform the samples to unconstrained space and compute the joint log probability.
-    if !DynamicPPL.islinked(vi)
-        vi = DynamicPPL.link!!(vi, model)
-        vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
-    end
+    vi = ldf.varinfo

     # Compute initial sample and state.
-    sample = Transition(model, vi)
-    ℓ = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-        adtype=spl.alg.adtype,
-    )
-    state = SGHMCState(ℓ, vi, zero(vi[:]))
+    sample = Transition(ldf.model, vi)
+    state = SGHMCState(vi, zero(vi[:]))

     return sample, state
 end

 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model::Model,
+    ldf::DynamicPPL.LogDensityFunction,
     spl::Sampler{<:SGHMC},
     state::SGHMCState;
     kwargs...,
 )
     # Compute gradient of log density.
-    ℓ = state.logdensity
     vi = state.vi
     θ = vi[:]
-    grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
+    grad = last(LogDensityProblems.logdensity_and_gradient(ldf, θ))

     # Update latent variables and velocity according to
     # equation (15) of Chen et al. (2014)
@@ -100,11 +87,11 @@ function AbstractMCMC.step(

     # Save new variables and recompute log density.
     vi = DynamicPPL.unflatten(vi, θ)
-    vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
+    vi = last(DynamicPPL.evaluate!!(ldf.model, vi, DynamicPPL.SamplingContext(rng, spl)))

     # Compute next sample and state.
-    sample = Transition(model, vi)
-    newstate = SGHMCState(ℓ, vi, newv)
+    sample = Transition(ldf.model, vi)
+    newstate = SGHMCState(vi, newv)

     return sample, newstate
 end
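The update referenced by the comment above (equation (15) of Chen et al., 2014) falls outside the changed hunks, so it is not shown here. As a minimal standalone sketch, assuming the sampler's `learning_rate` and `momentum_decay` fields play the roles of the step size and friction term, the velocity/position update looks roughly like this (an illustration of the equation, not the elided lines):

```julia
using Random

# Illustrative SGHMC update in the spirit of equation (15) of Chen et al. (2014):
# the velocity gets a gradient step, friction, and injected Gaussian noise, and
# the position moves by the new velocity. `learning_rate` and `momentum_decay`
# mirror the sampler's fields; this is a sketch, not the code elided above.
function sghmc_update(rng::Random.AbstractRNG, θ, v, grad, learning_rate, momentum_decay)
    noise = sqrt(2 * learning_rate * momentum_decay) .* randn(rng, eltype(v), length(v))
    newv = (1 - momentum_decay) .* v .+ learning_rate .* grad .+ noise
    newθ = θ .+ newv
    return newθ, newv
end
```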
@@ -208,57 +195,45 @@ metadata(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize)

 DynamicPPL.getlogp(t::SGLDTransition) = t.lp

-struct SGLDState{L,V<:AbstractVarInfo}
-    logdensity::L
+struct SGLDState{V<:AbstractVarInfo}
     vi::V
     step::Int
 end

-function DynamicPPL.initialstep(
+function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model::Model,
-    spl::Sampler{<:SGLD},
-    vi::AbstractVarInfo;
+    ldf::DynamicPPL.LogDensityFunction,
+    spl::Sampler{<:SGLD};
     kwargs...,
 )
-    # Transform the samples to unconstrained space and compute the joint log probability.
-    if !DynamicPPL.islinked(vi)
-        vi = DynamicPPL.link!!(vi, model)
-        vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
-    end
-
     # Create first sample and state.
-    sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
-    ℓ = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-        adtype=spl.alg.adtype,
-    )
-    state = SGLDState(ℓ, vi, 1)
-
+    vi = ldf.varinfo
+    sample = SGLDTransition(ldf.model, vi, zero(spl.alg.stepsize(0)))
+    state = SGLDState(vi, 1)
     return sample, state
 end

 function AbstractMCMC.step(
-    rng::Random.AbstractRNG, model::Model, spl::Sampler{<:SGLD}, state::SGLDState; kwargs...
+    rng::Random.AbstractRNG,
+    ldf::LogDensityFunction,
+    spl::Sampler{<:SGLD},
+    state::SGLDState;
+    kwargs...,
 )
     # Perform gradient step.
-    ℓ = state.logdensity
     vi = state.vi
     θ = vi[:]
-    grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
+    grad = last(LogDensityProblems.logdensity_and_gradient(ldf, θ))
     step = state.step
     stepsize = spl.alg.stepsize(step)
     θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))

     # Save new variables and recompute log density.
     vi = DynamicPPL.unflatten(vi, θ)
-    vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
+    vi = last(DynamicPPL.evaluate!!(ldf.model, vi, DynamicPPL.SamplingContext(rng, spl)))

     # Compute next sample and state.
-    sample = SGLDTransition(model, vi, stepsize)
-    newstate = SGLDState(ℓ, vi, state.step + 1)
-
+    sample = SGLDTransition(ldf.model, vi, stepsize)
+    newstate = SGLDState(vi, state.step + 1)
     return sample, newstate
 end
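Both samplers now implement the standard two-method `AbstractMCMC.step` protocol on a `DynamicPPL.LogDensityFunction` instead of `DynamicPPL.initialstep` on a `Model`. Below is a minimal sketch of the driving loop this implies; how `ldf` and `spl` are constructed (linking the `VarInfo`, choosing the `adtype`) is assumed rather than shown:

```julia
using Random
import AbstractMCMC, DynamicPPL

# Sketch of the loop the refactor targets: the stateless first call replaces
# the old DynamicPPL.initialstep, and each later call threads the returned
# state back in. Building `ldf`/`spl` (linking, adtype) is assumed here.
function draw(rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl, n::Int)
    transitions = Vector{Any}(undef, n)
    transitions[1], state = AbstractMCMC.step(rng, ldf, spl)
    for i in 2:n
        transitions[i], state = AbstractMCMC.step(rng, ldf, spl, state)
    end
    return transitions
end
```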