Commit 6ecc146

Fix strictness failure with DifferentiationInterface 0.7
1 parent bfa0a9c commit 6ecc146
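
DifferentiationInterface 0.7 checks more strictly that the function passed to a gradient call matches the one the preparation object was built for. Every anonymous function in Julia has its own type, so constructing the closure twice (once for DI.prepare_gradient and again for DI.value_and_gradient) can no longer be relied on to pass that check; the fix is to build the closure once and store it in the LogDensityFunction. A minimal sketch of the failure mode outside DynamicPPL, assuming an AutoForwardDiff backend and a toy objective (both illustrative, not part of this commit):

import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff  # illustrative backend choice, not part of this commit
import ForwardDiff              # loads the corresponding backend implementation

backend = AutoForwardDiff()
x = [1.0, 2.0]

# Preparing with one anonymous closure and differentiating with a second,
# syntactically identical one produces two distinct function types, which the
# stricter preparation check in DifferentiationInterface 0.7 can reject.
prep = DI.prepare_gradient(x -> sum(abs2, x), backend, x)
# DI.value_and_gradient(x -> sum(abs2, x), prep, backend, x)  # may now error

# Reusing the very same closure object keeps the prepared signature identical
# to the one actually differentiated, which is what the new `closure` field does.
f = x -> sum(abs2, x)
prep = DI.prepare_gradient(f, backend, x)
DI.value_and_gradient(f, prep, backend, x)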

File tree

1 file changed: +15 -7 lines changed


src/logdensityfunction.jl

Lines changed: 15 additions & 7 deletions
@@ -106,6 +106,8 @@ struct LogDensityFunction{
     adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}
+    "(internal use only) the closure used for the gradient preparation"
+    closure::Union{Nothing,Function}
 
     function LogDensityFunction(
         model::Model,
@@ -114,6 +116,7 @@ struct LogDensityFunction{
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
         if adtype === nothing
+            closure = nothing
             prep = nothing
         else
             # Make backend-specific tweaks to the adtype
@@ -124,10 +127,16 @@ struct LogDensityFunction{
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
             if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    x -> logdensity_at(x, model, varinfo, context), adtype, x
-                )
+                # The closure itself has to be stored inside the
+                # LogDensityFunction to ensure that the signature of the
+                # function being differentiated is the same as that used for
+                # preparation. See
+                # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
+                # explanation.
+                closure = x -> logdensity_at(x, model, varinfo, context)
+                prep = DI.prepare_gradient(closure, adtype, x)
             else
+                closure = nothing
                 prep = DI.prepare_gradient(
                     logdensity_at,
                     adtype,
@@ -139,7 +148,7 @@ struct LogDensityFunction{
             end
         end
         return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep
+            model, varinfo, context, adtype, prep, closure
         )
     end
 end
@@ -208,9 +217,8 @@ function LogDensityProblems.logdensity_and_gradient(
     # Make branching statically inferrable, i.e. type-stable (even if the two
     # branches happen to return different types)
     return if use_closure(f.adtype)
-        DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
-        )
+        f.closure === nothing && error("Closure not available; this should not happen")
+        DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
     else
         DI.value_and_gradient(
             logdensity_at,

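For context, a rough usage sketch of the code path this commit touches, assuming `adtype` is accepted as a keyword argument of the constructor shown above and using a toy model (both assumptions for illustration, not taken from this commit): constructing the LogDensityFunction with an AD type runs the preparation branch once, and LogDensityProblems.logdensity_and_gradient then reuses the prepared gradient, together with the stored closure when use_closure(adtype) holds.

using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff  # illustrative backend choice
import ForwardDiff

# Hypothetical one-parameter model, just to exercise the gradient path.
@model function demo()
    y ~ Normal()
end

# Assuming the keyword form shown here; preparation (and the stored closure,
# when use_closure(adtype) holds) is set up once in the constructor.
ldf = DynamicPPL.LogDensityFunction(demo(); adtype=AutoForwardDiff())

# Reuses the prepared gradient rather than rebuilding a fresh closure per call.
LogDensityProblems.logdensity_and_gradient(ldf, [0.1])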