@@ -106,6 +106,8 @@ struct LogDensityFunction{
     adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}
+    "(internal use only) the closure used for the gradient preparation"
+    closure::Union{Nothing,Function}
 
     function LogDensityFunction(
         model::Model,
@@ -114,6 +116,7 @@ struct LogDensityFunction{
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
         if adtype === nothing
+            closure = nothing
             prep = nothing
         else
             # Make backend-specific tweaks to the adtype
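As a reading aid: together, these first two hunks give LogDensityFunction a `closure` field that travels alongside `prep`. Below is a sketch of the field layout they imply, with field names and order taken from the `new{...}(...)` call further down; the type-parameter names and the exact bounds are assumptions, not part of the PR.

    # Sketch only; reconstructed from this diff, not the actual source.
    struct LogDensityFunction{M,V,C,AD}
        model::M
        varinfo::V
        context::C
        adtype::AD
        "(internal use only) gradient preparation object for the model"
        prep::Union{Nothing,DI.GradientPrep}
        "(internal use only) the closure used for the gradient preparation"
        closure::Union{Nothing,Function}
        # ... inner constructor as shown in the surrounding hunks ...
    end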
@@ -124,10 +127,16 @@ struct LogDensityFunction{
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
             if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    x -> logdensity_at(x, model, varinfo, context), adtype, x
-                )
+                # The closure itself has to be stored inside the
+                # LogDensityFunction to ensure that the signature of the
+                # function being differentiated is the same as that used for
+                # preparation. See
+                # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
+                # explanation.
+                closure = x -> logdensity_at(x, model, varinfo, context)
+                prep = DI.prepare_gradient(closure, adtype, x)
             else
+                closure = nothing
                 prep = DI.prepare_gradient(
                     logdensity_at,
                     adtype,
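The new comment is the heart of this change. In Julia, every anonymous-function literal defines its own type, so rebuilding a syntactically identical closure at the call site produces a different function type from the one the preparation was made with, and DI only guarantees a prep object is valid for the very function it was prepared with. A minimal standalone illustration (names are illustrative only):

    f1 = x -> sum(abs2, x)
    f2 = x -> sum(abs2, x)   # same source text, but a separate literal
    typeof(f1) == typeof(f2) # false: each literal gets its own closure type

Storing the closure object itself, and reusing that same object at differentiation time, guarantees the two types match.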
@@ -139,7 +148,7 @@ struct LogDensityFunction{
             end
         end
         return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep
+            model, varinfo, context, adtype, prep, closure
         )
     end
 end
@@ -208,9 +217,8 @@ function LogDensityProblems.logdensity_and_gradient(
     # Make branching statically inferrable, i.e. type-stable (even if the two
     # branches happen to return different types)
     return if use_closure(f.adtype)
-        DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
-        )
+        f.closure === nothing && error("Closure not available; this should not happen")
+        DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
     else
         DI.value_and_gradient(
             logdensity_at,
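For reference, the prepared-gradient calling pattern this branch now uses, as a self-contained sketch; the stand-in function, backend, and input are illustrative rather than taken from the PR:

    import DifferentiationInterface as DI
    import ForwardDiff               # backend package must be loaded
    using ADTypes: AutoForwardDiff

    backend = AutoForwardDiff()
    f = x -> sum(abs2, x)            # stand-in for the stored f.closure
    x = randn(3)
    prep = DI.prepare_gradient(f, backend, x)
    # The same f object prepared above is passed again here, which is
    # exactly what storing the closure in the struct makes possible:
    y, g = DI.value_and_gradient(f, prep, backend, x)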