@@ -95,7 +95,11 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
95
95
```
96
96
"""
97
97
struct LogDensityFunction{
98
- M<: Model ,F<: Function ,V<: AbstractVarInfo ,C<: AbstractContext ,AD<: Union{Nothing,ADTypes.AbstractADType}
98
+ M<: Model ,
99
+ F<: Function ,
100
+ V<: AbstractVarInfo ,
101
+ C<: AbstractContext ,
102
+ AD<: Union{Nothing,ADTypes.AbstractADType} ,
99
103
}
100
104
" model used for evaluation"
101
105
model:: M
@@ -143,7 +147,13 @@ struct LogDensityFunction{
143
147
)
144
148
end
145
149
end
146
- return new {typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(context),typeof(adtype)} (
150
+ return new{
151
+ typeof (model),
152
+ typeof (getlogdensity),
153
+ typeof (varinfo),
154
+ typeof (context),
155
+ typeof (adtype),
156
+ }(
147
157
model, getlogdensity, varinfo, context, adtype, prep
148
158
)
149
159
end
@@ -177,12 +187,12 @@ Create the default AbstractVarInfo that should be used for evaluating the log de
177
187
Only the accumulators necesessary for `getlogdensity` will be used.
178
188
"""
179
189
function ldf_default_varinfo (:: Model , getlogdensity:: Function )
180
- msg = """
181
- LogDensityFunction does not know what sort of VarInfo should be used when \
182
- `getlogdensity` is $getlogdensity . Please specify a VarInfo explicitly.
183
- """
184
- error (msg)
185
- end
190
+ msg = """
191
+ LogDensityFunction does not know what sort of VarInfo should be used when \
192
+ `getlogdensity` is $getlogdensity . Please specify a VarInfo explicitly.
193
+ """
194
+ return error (msg)
195
+ end
186
196
187
197
ldf_default_varinfo (model:: Model , :: typeof (getlogjoint)) = VarInfo (model)
188
198
@@ -210,7 +220,11 @@ into it, and its own parameters are discarded. `getlogdensity` is the function t
210
220
the log density from the evaluated varinfo.
211
221
"""
212
222
function logdensity_at (
213
- x:: AbstractVector , model:: Model , getlogdensity:: Function , varinfo:: AbstractVarInfo , context:: AbstractContext
223
+ x:: AbstractVector ,
224
+ model:: Model ,
225
+ getlogdensity:: Function ,
226
+ varinfo:: AbstractVarInfo ,
227
+ context:: AbstractContext ,
214
228
)
215
229
varinfo_new = unflatten (varinfo, x)
216
230
varinfo_eval = last (evaluate!! (model, varinfo_new, context))
@@ -242,7 +256,10 @@ function LogDensityProblems.logdensity_and_gradient(
242
256
# branches happen to return different types)
243
257
return if use_closure (f. adtype)
244
258
DI. value_and_gradient (
245
- x -> logdensity_at (x, f. model, f. getlogdensity, f. varinfo, f. context), f. prep, f. adtype, x
259
+ x -> logdensity_at (x, f. model, f. getlogdensity, f. varinfo, f. context),
260
+ f. prep,
261
+ f. adtype,
262
+ x,
246
263
)
247
264
else
248
265
DI. value_and_gradient (
0 commit comments