From 1f1ec85c66e824fb2be3c5cfc842ccde8c67b8e7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 20 May 2025 17:38:14 +0100 Subject: [PATCH 1/5] Give LogDensityFunction the getlogdensity field --- benchmarks/src/DynamicPPLBenchmarks.jl | 2 +- src/logdensityfunction.jl | 96 ++++++++++++++++---------- test/ad.jl | 6 +- test/logdensityfunction.jl | 11 ++- test/test_util.jl | 2 +- 5 files changed, 74 insertions(+), 43 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 16338de2f..602443194 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -87,7 +87,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend) + f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 1b5e9b8c4..9e3d45f2b 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -17,7 +17,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), + getlogdensity::Function=getlogjoint, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity), context::AbstractContext=DefaultContext(); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -28,10 +29,10 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with its -the type of varinfo to be used, as well as the evaluation context. These must -be known in order to calculate the log density (using -[`DynamicPPL.evaluate!!`](@ref)). +At its most basic level, a LogDensityFunction wraps the model together with +the type of varinfo to be used, as well as the evaluation context and a function +to extract the log density from the VarInfo. These must be known in order to +calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)). If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -73,13 +74,13 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # LogDensityFunction respects the accumulators in VarInfo: - f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); +julia> # One can also specify evaluating e.g. the log prior only: + f_prior = LogDensityFunction(model, getprior); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -94,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,F<:Function,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} } "model used for evaluation" model::M - "varinfo used for evaluation" + "function to be called on `varinfo` to extract the log density. By default `getlogjoint`." + getlogdensity::F + "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C @@ -109,7 +112,8 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), + getlogdensity::Function=getlogjoint, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity), context::AbstractContext=leafcontext(model.context); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) @@ -125,7 +129,7 @@ struct LogDensityFunction{ x = map(identity, varinfo[:]) if use_closure(adtype) prep = DI.prepare_gradient( - x -> logdensity_at(x, model, varinfo, context), adtype, x + x -> logdensity_at(x, model, getlogdensity, varinfo, context), adtype, x ) else prep = DI.prepare_gradient( @@ -133,13 +137,14 @@ struct LogDensityFunction{ adtype, x, DI.Constant(model), + DI.Constant(getlogdensity), DI.Constant(varinfo), DI.Constant(context), ) end end - return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(context),typeof(adtype)}( + model, getlogdensity, varinfo, context, adtype, prep ) end end @@ -164,10 +169,36 @@ function LogDensityFunction( end end +""" + ldf_default_varinfo(model::Model, getlogdensity::Function) + +Create the default AbstractVarInfo that should be used for evaluating the log density. + +Only the accumulators necesessary for `getlogdensity` will be used. +""" +function ldf_default_varinfo(::Model, getlogdensity::Function) + msg = """ + LogDensityFunction does not know what sort of VarInfo should be used when \ + `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. + """ + error(msg) + end + +ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) +end + +function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) + return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +end + """ logdensity_at( x::AbstractVector, model::Model, + getlogdensity::Function, varinfo::AbstractVarInfo, context::AbstractContext ) @@ -175,45 +206,35 @@ end Evaluate the log density of the given `model` at the given parameter values `x`, using the given `varinfo` and `context`. Note that the `varinfo` argument is provided only for its structure, in the sense that the parameters from the vector `x` are inserted -into it, and its own parameters are discarded. It does, however, determine whether the log -prior, likelihood, or joint is returned, based on which accumulators are set in it. +into it, and its own parameters are discarded. `getlogdensity` is the function that extracts +the log density from the evaluated varinfo. """ function logdensity_at( - x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext + x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo, context::AbstractContext ) varinfo_new = unflatten(varinfo, x) varinfo_eval = last(evaluate!!(model, varinfo_new, context)) - has_prior = hasacc(varinfo_eval, Val(:LogPrior)) - has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) - if has_prior && has_likelihood - return getlogjoint(varinfo_eval) - elseif has_prior - return getlogprior(varinfo_eval) - elseif has_likelihood - return getloglikelihood(varinfo_eval) - else - error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") - end + return getlogdensity(varinfo_eval) end ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,Nothing}} -) where {M,V,C} + ::Type{<:LogDensityFunction{M,F,V,C,Nothing}} +) where {M,F,V,C} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,AD}} -) where {M,V,C,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,F,V,C,AD}} +) where {M,F,V,C,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo, f.context) + return logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,C,AD}, x::AbstractVector -) where {M,V,C,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,F,V,C,AD}, x::AbstractVector +) where {M,F,V,C,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type @@ -221,7 +242,7 @@ function LogDensityProblems.logdensity_and_gradient( # branches happen to return different types) return if use_closure(f.adtype) DI.value_and_gradient( - x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x + x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context), f.prep, f.adtype, x ) else DI.value_and_gradient( @@ -230,6 +251,7 @@ function LogDensityProblems.logdensity_and_gradient( f.adtype, x, DI.Constant(f.model), + DI.Constant(f.getlogdensity), DI.Constant(f.varinfo), DI.Constant(f.context), ) @@ -304,7 +326,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. """ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) + return LogDensityFunction(model, f.getlogdensity, f.varinfo, f.context; adtype=f.adtype) end """ diff --git a/test/ad.jl b/test/ad.jl index 69ab99e19..afc93e9dd 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -24,10 +24,10 @@ using DynamicPPL: LogDensityFunction @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, linked_varinfo) + f = LogDensityFunction(m, getlogjoint, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) + ref_ldf = LogDensityFunction(m, getlogjoint, linked_varinfo; adtype=ref_adtype) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" for adtype in test_adtypes @@ -106,7 +106,7 @@ using DynamicPPL: LogDensityFunction spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) ldf = LogDensityFunction( - model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + model, getlogjoint, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) ) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index d6e66ec59..2504192a5 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -15,8 +15,17 @@ end vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + vi = first(varinfos) + theta = vi[:] + ldf_joint = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) + ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) + @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) + ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) + @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ loglikelihood(model, vi) + @testset "$(varinfo)" for varinfo in varinfos - logdensity = DynamicPPL.LogDensityFunction(model, varinfo) + logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) diff --git a/test/test_util.jl b/test/test_util.jl index 902dd7230..163a63f4e 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -14,7 +14,7 @@ function test_model_ad(model, logp_manual) x = vi[:] # Log probabilities using the model. - ℓ = DynamicPPL.LogDensityFunction(model, vi) + ℓ = DynamicPPL.LogDensityFunction(model, getlogjoint, vi) logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) # Check that both functions return the same values. From e7077ba314b3d2efd7a1616ad0ea068cbd7cc53e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 20 May 2025 17:40:49 +0100 Subject: [PATCH 2/5] Allow missing LogPriorAccumulator when linking --- src/simple_varinfo.jl | 8 ++++++-- src/varinfo.jl | 20 +++++++++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 42fcedfb8..151b4fca7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -606,7 +606,9 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) vi_new = Accessors.@set(vi.values = y) - vi_new = acclogprior!!(vi_new, -logjac) + if hasacc(vi_new, Val(:LogPrior)) + vi_new = acclogprior!!(vi_new, -logjac) + end return settrans!!(vi_new, t) end @@ -619,7 +621,9 @@ function invlink!!( y = vi.values x, logjac = with_logabsdet_jacobian(b, y) vi_new = Accessors.@set(vi.values = x) - vi_new = acclogprior!!(vi_new, logjac) + if hasacc(vi_new, Val(:LogPrior)) + vi_new = acclogprior!!(vi_new, logjac) + end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 6a968da4d..85ea0757b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1241,7 +1241,9 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. setval!(md, yvec, vn) - vi = acclogprior!!(vi, -logjac) + if hasacc(vi, Val(:LogPrior)) + vi = acclogprior!!(vi, -logjac) + end return vi end @@ -1278,7 +1280,9 @@ function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1292,7 +1296,9 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1441,7 +1447,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1455,7 +1463,9 @@ function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end From 216789cb84f297b788afd88bf12392b29afbace5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 20 May 2025 17:44:36 +0100 Subject: [PATCH 3/5] Trim whitespace --- src/varname.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varname.jl b/src/varname.jl index c16587065..3eb1f2460 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -7,7 +7,7 @@ This is a very restricted version `subumes(u::VarName, v::VarName)` only really - Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. ## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` +- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, and similarly to `v`. But this is slow. """ From 59e22a2b39b4ab0a822b9db1c561f62db8476c66 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 20 May 2025 17:53:43 +0100 Subject: [PATCH 4/5] Run formatter --- benchmarks/src/DynamicPPLBenchmarks.jl | 4 ++- src/logdensityfunction.jl | 37 +++++++++++++++++++------- test/ad.jl | 10 +++++-- test/logdensityfunction.jl | 3 ++- 4 files changed, 40 insertions(+), 14 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 602443194..3707356b6 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -87,7 +87,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9e3d45f2b..6ce303c00 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -95,7 +95,11 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model, + F<:Function, + V<:AbstractVarInfo, + C<:AbstractContext, + AD<:Union{Nothing,ADTypes.AbstractADType}, } "model used for evaluation" model::M @@ -143,7 +147,13 @@ struct LogDensityFunction{ ) end end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(context),typeof(adtype)}( + return new{ + typeof(model), + typeof(getlogdensity), + typeof(varinfo), + typeof(context), + typeof(adtype), + }( model, getlogdensity, varinfo, context, adtype, prep ) end @@ -177,12 +187,12 @@ Create the default AbstractVarInfo that should be used for evaluating the log de Only the accumulators necesessary for `getlogdensity` will be used. """ function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. - """ - error(msg) - end + msg = """ + LogDensityFunction does not know what sort of VarInfo should be used when \ + `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. + """ + return error(msg) +end ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) @@ -210,7 +220,11 @@ into it, and its own parameters are discarded. `getlogdensity` is the function t the log density from the evaluated varinfo. """ function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo, context::AbstractContext + x::AbstractVector, + model::Model, + getlogdensity::Function, + varinfo::AbstractVarInfo, + context::AbstractContext, ) varinfo_new = unflatten(varinfo, x) varinfo_eval = last(evaluate!!(model, varinfo_new, context)) @@ -242,7 +256,10 @@ function LogDensityProblems.logdensity_and_gradient( # branches happen to return different types) return if use_closure(f.adtype) DI.value_and_gradient( - x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context), f.prep, f.adtype, x + x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context), + f.prep, + f.adtype, + x, ) else DI.value_and_gradient( diff --git a/test/ad.jl b/test/ad.jl index afc93e9dd..1fcf2810e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -27,7 +27,9 @@ using DynamicPPL: LogDensityFunction f = LogDensityFunction(m, getlogjoint, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, getlogjoint, linked_varinfo; adtype=ref_adtype) + ref_ldf = LogDensityFunction( + m, getlogjoint, linked_varinfo; adtype=ref_adtype + ) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" for adtype in test_adtypes @@ -106,7 +108,11 @@ using DynamicPPL: LogDensityFunction spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) ldf = LogDensityFunction( - model, getlogjoint, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + model, + getlogjoint, + vi, + SamplingContext(spl); + adtype=AutoReverseDiff(; compile=true), ) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 2504192a5..c4d0d6beb 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -22,7 +22,8 @@ end ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ loglikelihood(model, vi) + @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ + loglikelihood(model, vi) @testset "$(varinfo)" for varinfo in varinfos logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) From 175b6334f1b3a8d5af34785b59c1cf35df694c67 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 21 May 2025 17:50:51 +0100 Subject: [PATCH 5/5] Fix a few typos --- src/logdensityfunction.jl | 2 +- src/test_utils/ad.jl | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 6ce303c00..0ffd959bb 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -80,7 +80,7 @@ julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getprior); + f_prior = LogDensityFunction(model, getlogprior); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d38915c12..85a781be7 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: Random, Xoshiro using Statistics: median @@ -184,7 +184,7 @@ function run_ad( verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo; adtype=adtype) + ldf = LogDensityFunction(model, getlogjoint, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) grad = collect(grad) @@ -193,7 +193,9 @@ function run_ad( if test # Calculate ground truth to compare against value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) + ldf_reference = LogDensityFunction( + model, getlogjoint, varinfo; adtype=reference_adtype + ) logdensity_and_gradient(ldf_reference, params) else expected_value_and_grad