@@ -4,7 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
4
4
using Chairmarks: @be
5
5
import DifferentiationInterface as DI
6
6
using DocStringExtensions
7
- using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
7
+ using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
8
8
using LogDensityProblems: logdensity, logdensity_and_gradient
9
9
using Random: Random, Xoshiro
10
10
using Statistics: median
@@ -184,7 +184,7 @@ function run_ad(
184
184
185
185
verbose && @info " Running AD on $(model. f) with $(adtype) \n "
186
186
verbose && println (" params : $(params) " )
187
- ldf = LogDensityFunction (model, varinfo; adtype= adtype)
187
+ ldf = LogDensityFunction (model, getlogjoint, varinfo; adtype= adtype)
188
188
189
189
value, grad = logdensity_and_gradient (ldf, params)
190
190
grad = collect (grad)
@@ -193,7 +193,9 @@ function run_ad(
193
193
if test
194
194
# Calculate ground truth to compare against
195
195
value_true, grad_true = if expected_value_and_grad === nothing
196
- ldf_reference = LogDensityFunction (model, varinfo; adtype= reference_adtype)
196
+ ldf_reference = LogDensityFunction (
197
+ model, getlogjoint, varinfo; adtype= reference_adtype
198
+ )
197
199
logdensity_and_gradient (ldf_reference, params)
198
200
else
199
201
expected_value_and_grad
0 commit comments