Skip to content

Commit 59e22a2

Browse files
committed
Run formatter
1 parent 216789c commit 59e22a2

File tree

4 files changed

+40
-14
lines changed

4 files changed

+40
-14
lines changed

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8787
vi = DynamicPPL.link(vi, model)
8888
end
8989

90-
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend)
90+
f = DynamicPPL.LogDensityFunction(
91+
model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend
92+
)
9193
# The parameters at which we evaluate f.
9294
θ = vi[:]
9395

src/logdensityfunction.jl

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
9595
```
9696
"""
9797
struct LogDensityFunction{
98-
M<:Model,F<:Function,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
98+
M<:Model,
99+
F<:Function,
100+
V<:AbstractVarInfo,
101+
C<:AbstractContext,
102+
AD<:Union{Nothing,ADTypes.AbstractADType},
99103
}
100104
"model used for evaluation"
101105
model::M
@@ -143,7 +147,13 @@ struct LogDensityFunction{
143147
)
144148
end
145149
end
146-
return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(context),typeof(adtype)}(
150+
return new{
151+
typeof(model),
152+
typeof(getlogdensity),
153+
typeof(varinfo),
154+
typeof(context),
155+
typeof(adtype),
156+
}(
147157
model, getlogdensity, varinfo, context, adtype, prep
148158
)
149159
end
@@ -177,12 +187,12 @@ Create the default AbstractVarInfo that should be used for evaluating the log de
177187
Only the accumulators necesessary for `getlogdensity` will be used.
178188
"""
179189
function ldf_default_varinfo(::Model, getlogdensity::Function)
180-
msg = """
181-
LogDensityFunction does not know what sort of VarInfo should be used when \
182-
`getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
183-
"""
184-
error(msg)
185-
end
190+
msg = """
191+
LogDensityFunction does not know what sort of VarInfo should be used when \
192+
`getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
193+
"""
194+
return error(msg)
195+
end
186196

187197
ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)
188198

@@ -210,7 +220,11 @@ into it, and its own parameters are discarded. `getlogdensity` is the function t
210220
the log density from the evaluated varinfo.
211221
"""
212222
function logdensity_at(
213-
x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo, context::AbstractContext
223+
x::AbstractVector,
224+
model::Model,
225+
getlogdensity::Function,
226+
varinfo::AbstractVarInfo,
227+
context::AbstractContext,
214228
)
215229
varinfo_new = unflatten(varinfo, x)
216230
varinfo_eval = last(evaluate!!(model, varinfo_new, context))
@@ -242,7 +256,10 @@ function LogDensityProblems.logdensity_and_gradient(
242256
# branches happen to return different types)
243257
return if use_closure(f.adtype)
244258
DI.value_and_gradient(
245-
x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context), f.prep, f.adtype, x
259+
x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context),
260+
f.prep,
261+
f.adtype,
262+
x,
246263
)
247264
else
248265
DI.value_and_gradient(

test/ad.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ using DynamicPPL: LogDensityFunction
2727
f = LogDensityFunction(m, getlogjoint, linked_varinfo)
2828
x = DynamicPPL.getparams(f)
2929
# Calculate reference logp + gradient of logp using ForwardDiff
30-
ref_ldf = LogDensityFunction(m, getlogjoint, linked_varinfo; adtype=ref_adtype)
30+
ref_ldf = LogDensityFunction(
31+
m, getlogjoint, linked_varinfo; adtype=ref_adtype
32+
)
3133
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
3234

3335
@testset "$adtype" for adtype in test_adtypes
@@ -106,7 +108,11 @@ using DynamicPPL: LogDensityFunction
106108
spl = Sampler(MyEmptyAlg())
107109
vi = VarInfo(model)
108110
ldf = LogDensityFunction(
109-
model, getlogjoint, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
111+
model,
112+
getlogjoint,
113+
vi,
114+
SamplingContext(spl);
115+
adtype=AutoReverseDiff(; compile=true),
110116
)
111117
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
112118
end

test/logdensityfunction.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ end
2222
ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior)
2323
@test LogDensityProblems.logdensity(ldf_prior, theta) logprior(model, vi)
2424
ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood)
25-
@test LogDensityProblems.logdensity(ldf_likelihood, theta) loglikelihood(model, vi)
25+
@test LogDensityProblems.logdensity(ldf_likelihood, theta)
26+
loglikelihood(model, vi)
2627

2728
@testset "$(varinfo)" for varinfo in varinfos
2829
logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo)

0 commit comments

Comments
 (0)