From 6e32434922f29c7ee1402fbdc47e0529dbd0603c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 May 2025 16:13:06 +0100 Subject: [PATCH 1/2] First efforts towards DPPL 0.37 compat, WIP --- Project.toml | 2 +- ext/TuringOptimExt.jl | 33 +++---- src/mcmc/Inference.jl | 11 +-- src/mcmc/ess.jl | 20 ++-- src/mcmc/gibbs.jl | 12 ++- src/mcmc/hmc.jl | 4 - src/mcmc/is.jl | 4 - src/mcmc/mh.jl | 4 - src/mcmc/particle_mcmc.jl | 22 +++-- src/optimisation/Optimisation.jl | 147 +++++++++++++++--------------- test/Project.toml | 4 +- test/mcmc/Inference.jl | 45 --------- test/mcmc/hmc.jl | 2 + test/mcmc/mh.jl | 2 + test/optimisation/Optimisation.jl | 67 ++++---------- test/test_utils/ad_utils.jl | 22 ++--- 16 files changed, 147 insertions(+), 254 deletions(-) diff --git a/Project.toml b/Project.toml index 82d32ed82f..5c7aeebbfb 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.36" +DynamicPPL = "0.37" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index d6c253e2a2..9f5c51a2b4 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -34,8 +34,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _mle_optimize(model, init_vals, optimizer, options; kwargs...) @@ -57,8 +57,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -74,8 +74,9 @@ function Optim.optimize( end function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) + return _optimize(f, args...; kwargs...) end """ @@ -104,8 +105,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) 
@@ -127,8 +128,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -144,9 +145,11 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) + return _optimize(f, args...; kwargs...) end + """ _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...) @@ -166,7 +169,7 @@ function _optimize( # whether initialisation is really necessary at all vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals) vi = DynamicPPL.link(vi, f.ldf.model) - f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype) + f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype) init_vals = DynamicPPL.getparams(f.ldf) # Optimize! @@ -183,9 +186,7 @@ function _optimize( # Get the optimum in unconstrained space. `getparams` does the invlinking. vi = f.ldf.varinfo vi_optimum = DynamicPPL.unflatten(vi, M.minimizer) - logdensity_optimum = Optimisation.OptimLogDensity( - f.ldf.model, vi_optimum, f.ldf.context - ) + logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype) vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum) varnames = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0cbb45b48f..59e32fe951 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -22,8 +22,6 @@ using DynamicPPL: SampleFromPrior, SampleFromUniform, DefaultContext, - PriorContext, - LikelihoodContext, set_flag!, unset_flag! 
using Distributions, Libtask, Bijectors @@ -75,7 +73,6 @@ export InferenceAlgorithm, RepeatSampler, Prior, assume, - observe, predict, externalsampler @@ -182,12 +179,10 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) + vi = VarInfo() + vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),)) vi = last( - DynamicPPL.evaluate!!( - model, - VarInfo(), - SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()), - ), + DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior())), ) return vi, nothing end diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 5448173486..5205772032 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -49,7 +49,7 @@ function AbstractMCMC.step( rng, EllipticalSliceSampling.ESSModel( ESSPrior(model, spl, vi), - DynamicPPL.LogDensityFunction( + DynamicPPL.LogDensityFunction{:LogLikelihood}( model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()) ), ), @@ -59,7 +59,7 @@ function AbstractMCMC.step( # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setlogp!!(vi, state.loglikelihood) + vi = setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end @@ -108,20 +108,12 @@ end # Mean of prior distribution Distributions.mean(p::ESSPrior) = p.μ -# Evaluate log-likelihood of proposals -const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} - -(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f) - function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi + rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi ) - return DynamicPPL.tilde_assume( - rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi - ) + return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi) end -function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi) + return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index c1d6cd6cff..d110d4decb 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -32,7 +32,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context) # # Purpose: avoid triggering resampling of variables we're conditioning on. # - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`. # - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to # undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable # rather than only for the "true" observations. @@ -177,16 +177,18 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. - value, lp, _ = DynamicPPL.tilde_assume( + # TODO(mhauru) Fix accumulation here. In this branch anything that gets + # accumulated just gets discarded with `_`. 
+ value, _ = DynamicPPL.tilde_assume( child_context, right, vn, get_global_varinfo(context) ) - value, lp, vi + value, vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. - value, lp, new_global_vi = DynamicPPL.tilde_assume( + value, new_global_vi = DynamicPPL.tilde_assume( child_context, DynamicPPL.SampleFromPrior(), right, @@ -194,7 +196,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 098e62eb22..af271b0cfc 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -501,10 +501,6 @@ function DynamicPPL.assume( return DynamicPPL.assume(dist, vn, vi) end -function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi) - return DynamicPPL.observe(d, value, vi) -end - #### #### Default HMC stepsize and mass matrix adaptor #### diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index d83abd173c..9ad0e1f82a 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName end return r, 0, vi end - -function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi) - return logpdf(dist, value), vi -end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 6a03f5359f..83a29388cb 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -390,7 +390,3 @@ function DynamicPPL.assume( retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) return retval end - -function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi) - return DynamicPPL.observe(SampleFromPrior(), d, value, vi) -end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 5bb1103876..b1e38b8038 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -379,10 +379,11 @@ function DynamicPPL.assume( return r, lp, vi end -function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) - # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. - return logpdf(dist, value), trace_local_varinfo_maybe(vi) -end +# TODO(mhauru) Fix this. +# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) +# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. +# return logpdf(dist, value), trace_local_varinfo_maybe(vi) +# end function DynamicPPL.acclogp!!( context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp @@ -391,12 +392,13 @@ function DynamicPPL.acclogp!!( return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) end -function DynamicPPL.acclogp_observe!!( - context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp -) - Libtask.produce(logp) - return trace_local_varinfo_maybe(varinfo) -end +# TODO(mhauru) Fix this. 
+# function DynamicPPL.acclogp_observe!!(
+#     context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
+# )
+#     Libtask.produce(logp)
+#     return trace_local_varinfo_maybe(varinfo)
+# end
 
 # Convenient constructor
 function AdvancedPS.Trace(
diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl
index ddcc27b876..23da8b08a6 100644
--- a/src/optimisation/Optimisation.jl
+++ b/src/optimisation/Optimisation.jl
@@ -43,75 +43,87 @@ Concrete type for maximum a posteriori estimation. Only used for the Optim.jl in
 """
 struct MAP <: ModeEstimator end
 
+# Most of these functions for LogPriorWithoutJacobianAccumulator are copied from
+# LogPriorAccumulator. The only one that is different is the accumulate_assume!! one.
 """
-    OptimizationContext{C<:AbstractContext} <: AbstractContext
+    LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator
 
-The `OptimizationContext` transforms variables to their constrained space, but
-does not use the density with respect to the transformation. This context is
-intended to allow an optimizer to sample in R^n freely.
+Exactly like DynamicPPL.LogPriorAccumulator, but does not include the log determinant of the
+Jacobian of any variable transformations.
+
+Used for MAP optimisation.
 """
-struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
-    context::C
+struct LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator
+    logp::T
+end
 
-    function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext}
-        if !(
-            context isa Union{
-                DynamicPPL.DefaultContext,
-                DynamicPPL.LikelihoodContext,
-                DynamicPPL.PriorContext,
-            }
-        )
-            msg = """
-                `OptimizationContext` supports only leaf contexts of type
-                `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`,
-                and `DynamicPPL.PriorContext` (given: `$(typeof(context)))`
-                """
-            throw(ArgumentError(msg))
-        end
-        return new{C}(context)
-    end
+"""
+    LogPriorWithoutJacobianAccumulator{T}()
+
+Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero.
+"""
+LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = LogPriorWithoutJacobianAccumulator(zero(T))
+LogPriorWithoutJacobianAccumulator() = LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}()
+
+function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator)
+    return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))")
 end
 
-OptimizationContext(ctx::DynamicPPL.AbstractContext) = OptimizationContext{typeof(ctx)}(ctx)
+# We use the same name for LogPriorWithoutJacobianAccumulator as for LogPriorAccumulator.
+# This has three effects:
+# 1. You can't have a VarInfo with both accumulator types.
+# 2. When you call functions like `getlogprior` on a VarInfo, it will return the one without
+# the Jacobian term, as if that was the usual log prior.
+# 3. This may cause a small number of invalidations in DynamicPPL. I haven't checked, but I
+# suspect they will be negligible.
+# TODO(mhauru) Not sure I like this solution. It's kinda glib, but might confuse a reader
+# of the code who expects things like `getlogprior` to always get the LogPriorAccumulator
+# contents. Another solution would be welcome, but would need to play nicely with how
+# LogDensityFunction works, since it calls `getlogprior` explicitly.
+DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) = :LogPrior + +DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} = LogPriorWithoutJacobianAccumulator(zero(T)) + +function DynamicPPL.combine(acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp) +end -DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() +function Base.:+(acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp) +end -function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) - r = vi[vn, dist] - lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} - # MAP - Distributions.logpdf(dist, r) - else - # MLE - 0 - end - return r, lp, vi +Base.zero(acc::LogPriorWithoutJacobianAccumulator) = LogPriorWithoutJacobianAccumulator(zero(acc.logp)) + +function DynamicPPL.accumulate_assume!!(acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right) + return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val)) end +DynamicPPL.accumulate_observe!!(acc::LogPriorWithoutJacobianAccumulator, right, left, vn) = acc -function DynamicPPL.tilde_observe( - ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args... -) - return DynamicPPL.tilde_observe(ctx.context, args...) +function Base.convert(::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator) where {T} + return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) +end + +function DynamicPPL.convert_eltype(::Type{T}, acc::LogPriorWithoutJacobianAccumulator) where {T} + return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) end """ OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, + V<:DynamicPPL.AbstractVarInfo, AD<:ADTypes.AbstractADType } A struct that wraps a single LogDensityFunction. Can be invoked either using ```julia -OptimLogDensity(model, varinfo, ctx; adtype=adtype) +OptimLogDensity(model, varinfo; adtype=adtype) ``` or ```julia -OptimLogDensity(model, ctx; adtype=adtype) +OptimLogDensity(model; adtype=adtype) ``` If not specified, `adtype` defaults to `AutoForwardDiff()`. @@ -129,37 +141,20 @@ the underlying LogDensityFunction at the point `z`. This is done to satisfy the Optim.jl interface. 
```julia -optim_ld = OptimLogDensity(model, varinfo, ctx) +optim_ld = OptimLogDensity(model, varinfo) optim_ld(z) # returns -logp ``` """ struct OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, + V<:DynamicPPL.AbstractVarInfo, AD<:ADTypes.AbstractADType, } - ldf::DynamicPPL.LogDensityFunction{M,V,C,AD} -end - -function OptimLogDensity( - model::DynamicPPL.Model, - vi::DynamicPPL.VarInfo, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi, ctx; adtype=adtype)) + ldf::DynamicPPL.LogDensityFunction{M,V,AD} end -# No varinfo -function OptimLogDensity( - model::DynamicPPL.Model, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity( - DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype) - ) +function OptimLogDensity(model::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model); adtype=AutoForwardDiff()) + return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi; adtype=adtype)) end """ @@ -325,10 +320,11 @@ function StatsBase.informationmatrix( # Convert the values to their unconstrained states to make sure the # Hessian is computed with respect to the untransformed parameters. - linked = DynamicPPL.istrans(m.f.ldf.varinfo) + old_ldf = m.f.ldf + linked = DynamicPPL.istrans(old_ldf.varinfo) if linked - new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) + new_f = OptimLogDensity(old_ldf.model, new_vi; adtype=old_ldf.adtype) m = Accessors.@set m.f = new_f end @@ -339,8 +335,9 @@ function StatsBase.informationmatrix( # Link it back if we invlinked it. if linked - new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + invlinked_ldf = m.f.ldf + new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model) + new_f = OptimLogDensity(invlinked_ldf.model, new_vi; adtype=invlinked_ldf.adtype) m = Accessors.@set m.f = new_f end @@ -560,12 +557,11 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. - inner_context = if estimator isa MAP - DynamicPPL.DefaultContext() + accs = if estimator isa MAP + (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) else - DynamicPPL.LikelihoodContext() + (DynamicPPL.LogLikelihoodAccumulator(),) end - ctx = OptimizationContext(inner_context) # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated @@ -574,6 +570,7 @@ function estimate_mode( # directly on the fields of the LogDensityFunction vi = DynamicPPL.VarInfo(model) vi = DynamicPPL.unflatten(vi, initial_params) + vi = DynamicPPL.setaccs!!(vi, accs) # Link the varinfo if needed. # TODO(mhauru) We currently couple together the questions of whether the user specified @@ -585,7 +582,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi, ctx) + log_density = OptimLogDensity(model, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) 
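For reference, a minimal sketch of how the MLE and MAP objectives are assembled after this first patch, where the objective is selected by choosing accumulators on the `VarInfo` rather than by wrapping a context. The `gdemo` model and the probed parameter vector are hypothetical and only for illustration, and this constructor signature is the intermediate one from PATCH 1/2 (it is revised again in PATCH 2/2 below).

```julia
using Turing, DynamicPPL

# Hypothetical toy model, used only to illustrate the new constructor.
@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    return x ~ Normal(m, sqrt(s))
end

model = gdemo(1.5)

# MLE objective: only the log likelihood is accumulated.
vi_mle = DynamicPPL.setaccs!!(
    DynamicPPL.VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)
)
mle_obj = Turing.Optimisation.OptimLogDensity(model, vi_mle)

# MAP objective: log prior (without the Jacobian term) plus log likelihood.
vi_map = DynamicPPL.setaccs!!(
    DynamicPPL.VarInfo(model),
    (
        Turing.Optimisation.LogPriorWithoutJacobianAccumulator(),
        DynamicPPL.LogLikelihoodAccumulator(),
    ),
)
map_obj = Turing.Optimisation.OptimLogDensity(model, vi_map)

# Both callables return the *negative* log density at a parameter vector [s, m].
mle_obj([1.0, 0.0]), map_obj([1.0, 0.0])
```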
diff --git a/test/Project.toml b/test/Project.toml index df0af4c978..921a8a54c3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -42,7 +42,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" AbstractMCMC = "5" AbstractPPL = "0.9, 0.10, 0.11" AdvancedMH = "0.6, 0.7, 0.8" -AdvancedPS = "=0.6.0" +AdvancedPS = "0.6.0" AdvancedVI = "0.2" Aqua = "0.8" BangBang = "0.4" @@ -52,7 +52,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.36" +DynamicPPL = "0.37" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 36044a9e73..05b994a563 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -119,36 +119,6 @@ using Turing check_gdemo(chn3_contd) end - @testset "Contexts" begin - # Test LikelihoodContext - @model function testmodel1(x) - a ~ Beta() - lp1 = getlogp(__varinfo__) - x[1] ~ Bernoulli(a) - return global loglike = getlogp(__varinfo__) - lp1 - end - model = testmodel1([1.0]) - varinfo = Turing.VarInfo(model) - model(varinfo, Turing.SampleFromPrior(), Turing.LikelihoodContext()) - @test getlogp(varinfo) == loglike - - # Test MiniBatchContext - @model function testmodel2(x) - a ~ Beta() - return x[1] ~ Bernoulli(a) - end - model = testmodel2([1.0]) - varinfo1 = Turing.VarInfo(model) - varinfo2 = deepcopy(varinfo1) - model(varinfo1, Turing.SampleFromPrior(), Turing.LikelihoodContext()) - model( - varinfo2, - Turing.SampleFromPrior(), - Turing.MiniBatchContext(Turing.LikelihoodContext(), 10), - ) - @test isapprox(getlogp(varinfo2) / getlogp(varinfo1), 10) - end - @testset "Prior" begin N = 10_000 @@ -180,21 +150,6 @@ using Turing @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end - - @testset "#2169" begin - # Not exactly the same as the issue, but similar. - @model function issue2169_model() - if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext - x ~ Normal(0, 1) - else - x ~ Normal(1000, 1) - end - end - - model = issue2169_model() - chain = sample(StableRNG(seed), model, Prior(), 10) - @test all(mean(chain[:x]) .< 5) - end end @testset "chain ordering" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 4d884b1637..754ec4a932 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -210,6 +210,8 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end + # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, + # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? @testset "prior" begin @model function demo_hmc_prior() # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index d190e589a5..1614a9446c 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -265,6 +265,8 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @test !DynamicPPL.islinked(vi) end + # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, + # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? 
@testset "prior" begin alg = MH() gdemo_default_prior = DynamicPPL.contextualize( diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 9894d621ce..0b22f34534 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -28,27 +28,6 @@ using Turing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 @testset "OptimizationContext" begin - # Used for testing how well it works with nested contexts. - struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext - context::C - logprior_weight::T1 - loglikelihood_weight::T2 - end - DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent() - DynamicPPL.childcontext(parent::OverrideContext) = parent.context - DynamicPPL.setchildcontext(parent::OverrideContext, child) = - OverrideContext(child, parent.logprior_weight, parent.loglikelihood_weight) - - # Only implement what we need for the models above. - function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, context.logprior_weight, vi - end - function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return context.loglikelihood_weight, vi - end - @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -65,48 +44,34 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + @test Turing.Optimisation.OptimLogDensity(m1)(w) == + Turing.Optimisation.OptimLogDensity(m2)(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) - end - - @testset "Weighted" begin - function override(model) - return DynamicPPL.contextualize( - model, OverrideContext(model.context, 100, 1) - ) - end - m1 = override(model1(x)) - m2 = override(model2() | (x=x,)) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + @test Turing.Optimisation.OptimLogDensity(m1)(w) == + Turing.Optimisation.OptimLogDensity(m2)(w) end @testset "Default, Likelihood, Prior Contexts" begin m1 = model1(x) - defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext()) + vi = DynamicPPL.VarInfo(m1) + vi_joint = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(), LogLikelihoodAccumulator())) + vi_prior = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(),)) + vi_likelihood = DynamicPPL.setaccs!!(deepcopy(vi), (LogLikelihoodAccumulator(),)) a = [0.3] - @test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) == - Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) + - Turing.Optimisation.OptimLogDensity(m1, prictx)(a) + @test Turing.Optimisation.OptimLogDensity(m1, vi_joint)(a) 
== + Turing.Optimisation.OptimLogDensity(m1, vi_prior)(a) + + Turing.Optimisation.OptimLogDensity(m1, vi_likelihood)(a) - # test that PriorContext is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) ≈ + # test that the prior accumulator is calculating the right thing + @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -664,8 +629,8 @@ using Turing return nothing end m = saddle_model() - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - optim_ld = Turing.Optimisation.OptimLogDensity(m, ctx) + vi = DynamicPPL.setaccs!!(DynamicPPL.VarInfo(m), (LogLikelihoodAccumulator(),)) + optim_ld = Turing.Optimisation.OptimLogDensity(m, vi) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 309276407a..36bd7a9f68 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -147,35 +147,27 @@ end # child context. function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( DynamicPPL.childcontext(context), right, vn, vi ) check_adtype(context, vi) - return value, logp, vi + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi ) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) check_adtype(context, vi) - return value, logp, vi + return value, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vn, vi) + left, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vn, vi) check_adtype(context, vi) - return logp, vi -end - -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe( - DynamicPPL.childcontext(context), sampler, right, left, vi - ) - check_adtype(context, vi) - return logp, vi + return left, vi end # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # From 7948bc55752b860be48e63cc3b66f6577ea7f72a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 20 May 2025 17:53:05 +0100 Subject: [PATCH 2/2] More DPPL 0.37 compat work, WIP --- ext/TuringOptimExt.jl | 26 +++---- src/essential/container.jl | 7 +- src/mcmc/Inference.jl | 13 ++-- src/mcmc/ess.jl | 4 +- src/mcmc/hmc.jl | 11 ++- src/mcmc/particle_mcmc.jl | 16 ++-- src/optimisation/Optimisation.jl | 124 +++++++++++++++++++++--------- test/optimisation/Optimisation.jl | 45 ++++++----- test/test_utils/ad_utils.jl | 8 +- 9 files changed, 161 insertions(+), 93 deletions(-) diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index 9f5c51a2b4..635eb89111 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -34,8 +34,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), 
(DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _mle_optimize(model, init_vals, optimizer, options; kwargs...) @@ -57,8 +56,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -74,8 +72,7 @@ function Optim.optimize( end function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) - vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) return _optimize(f, args...; kwargs...) end @@ -105,8 +102,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) @@ -128,8 +124,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -145,8 +140,7 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) return _optimize(f, args...; kwargs...) end @@ -169,7 +163,9 @@ function _optimize( # whether initialisation is really necessary at all vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals) vi = DynamicPPL.link(vi, f.ldf.model) - f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype) + f = Optimisation.OptimLogDensity( + f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype + ) init_vals = DynamicPPL.getparams(f.ldf) # Optimize! @@ -186,7 +182,9 @@ function _optimize( # Get the optimum in unconstrained space. `getparams` does the invlinking. 
vi = f.ldf.varinfo vi_optimum = DynamicPPL.unflatten(vi, M.minimizer) - logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype) + logdensity_optimum = Optimisation.OptimLogDensity( + f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype + ) vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum) varnames = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) diff --git a/src/essential/container.jl b/src/essential/container.jl index a1012d471f..5c78e110fa 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -28,7 +28,8 @@ function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) # Make sure we load/reset the rng in the new replaying mechanism - DynamicPPL.increment_num_produce!(trace.model.f.varinfo) + # TODO(mhauru) Stop ignoring the return value. + DynamicPPL.increment_num_produce!!(trace.model.f.varinfo) isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) if score === nothing @@ -44,11 +45,13 @@ function AdvancedPS.delete_retained!(trace::TracedModel) end function AdvancedPS.reset_model(trace::TracedModel) - DynamicPPL.reset_num_produce!(trace.varinfo) + new_vi = DynamicPPL.reset_num_produce!!(trace.varinfo) + trace = TracedModel(trace.model, trace.sampler, new_vi, trace.evaluator) return trace end function AdvancedPS.reset_logprob!(trace::TracedModel) + # TODO(mhauru) Stop ignoring the return value. DynamicPPL.resetlogp!!(trace.model.varinfo) return trace end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 59e32fe951..5754e1ba4f 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -14,6 +14,7 @@ using DynamicPPL: push!!, setlogp!!, getlogp, + getlogjoint, VarName, getsym, getdist, @@ -182,7 +183,7 @@ function AbstractMCMC.step( vi = VarInfo() vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),)) vi = last( - DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior())), + DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior())) ) return vi, nothing end @@ -223,7 +224,7 @@ end Transition(θ, lp) = Transition(θ, lp, nothing) function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t) θ = getparams(model, vi) - lp = getlogp(vi) + lp = getlogjoint(vi) return Transition(θ, lp, getstats(t)) end @@ -236,10 +237,10 @@ function metadata(t::Transition) end end -DynamicPPL.getlogp(t::Transition) = t.lp +DynamicPPL.getlogjoint(t::Transition) = t.lp # Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) +metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),) # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. 
@@ -376,7 +377,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) end function get_transition_extras(ts::AbstractVector{<:VarInfo}) - valmat = reshape([getlogp(t) for t in ts], :, 1) + valmat = reshape([getlogjoint(t) for t in ts], :, 1) return [:lp], valmat end @@ -589,7 +590,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m` julia> transitions = Turing.Inference.transitions_from_chain(m, chain); -julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints +julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints 2-element Array{Float64,1}: -3.6294991938628374 -2.5697948166987845 diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 5205772032..aeafa13ad3 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -114,6 +114,8 @@ function DynamicPPL.tilde_assume( return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi) end -function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi) +function DynamicPPL.tilde_observe!!( + ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi +) return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi) end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index af271b0cfc..6e5ca3fcfa 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -199,7 +199,7 @@ function DynamicPPL.initialstep( end # Cache current log density. - log_density_old = getlogp(vi) + log_density_old = getloglikelihood(vi) # Find good eps if not provided one if iszero(spl.alg.ϵ) @@ -227,10 +227,12 @@ function DynamicPPL.initialstep( # Update `vi` based on acceptance if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # TODO(mhauru) Is setloglikelihood! the right thing here? + vi = setloglikelihood!!(vi, t.stat.log_density) else vi = DynamicPPL.unflatten(vi, theta) - vi = setlogp!!(vi, log_density_old) + # TODO(mhauru) Is setloglikelihood! the right thing here? + vi = setloglikelihood!!(vi, log_density_old) end transition = Transition(model, vi, t) @@ -275,7 +277,8 @@ function AbstractMCMC.step( vi = state.vi if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # TODO(mhauru) Is setloglikelihood! the right thing here? + vi = setloglikelihood!!(vi, t.stat.log_density) end # Compute next transition and state. diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index b1e38b8038..e1f4a1cfa9 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -118,10 +118,10 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. - reset_num_produce!(vi) + vi = reset_num_produce!!(vi) set_retained_vns_del!(vi) - resetlogp!!(vi) - empty!!(vi) + vi = resetlogp!!(vi) + vi = empty!!(vi) # Create a new set of particles. particles = AdvancedPS.ParticleContainer( @@ -252,9 +252,9 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo before new sweep - reset_num_produce!(vi) + vi = reset_num_produce!(vi) set_retained_vns_del!(vi) - resetlogp!!(vi) + vi = resetlogp!!(vi) # Create a new set of particles num_particles = spl.alg.nparticles @@ -284,8 +284,8 @@ function AbstractMCMC.step( ) # Reset the VarInfo before new sweep. vi = state.vi - reset_num_produce!(vi) - resetlogp!!(vi) + vi = reset_num_produce!(vi) + vi = resetlogp!!(vi) # Create reference particle for which the samples will be retained. 
        reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
@@ -408,7 +408,7 @@ function AdvancedPS.Trace(
     rng::AdvancedPS.TracedRNG,
 )
     newvarinfo = deepcopy(varinfo)
-    DynamicPPL.reset_num_produce!(newvarinfo)
+    newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo)
     tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng)
     newtrace = AdvancedPS.Trace(tmodel, rng)
diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl
index 23da8b08a6..80582019ea 100644
--- a/src/optimisation/Optimisation.jl
+++ b/src/optimisation/Optimisation.jl
@@ -62,56 +62,99 @@ end
 Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero.
 """
-LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = LogPriorWithoutJacobianAccumulator(zero(T))
-LogPriorWithoutJacobianAccumulator() = LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}()
+LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} =
+    LogPriorWithoutJacobianAccumulator(zero(T))
+function LogPriorWithoutJacobianAccumulator()
+    return LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}()
+end
 
 function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator)
     return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))")
 end
 
-# We use the same name for LogPriorWithoutJacobianAccumulator as for LogPriorAccumulator.
-# This has three effects:
-# 1. You can't have a VarInfo with both accumulator types.
-# 2. When you call functions like `getlogprior` on a VarInfo, it will return the one without
-# the Jacobian term, as if that was the usual log prior.
-# 3. This may cause a small number of invalidations in DynamicPPL. I haven't checked, but I
-# suspect they will be negligible.
-# TODO(mhauru) Not sure I like this solution. It's kinda glib, but might confuse a reader
-# of the code who expects things like `getlogprior` to always get the LogPriorAccumulator
-# contents. Another solution would be welcome, but would need to play nicely with how
-# LogDensityFunction works, since it calls `getlogprior` explicitly.
-DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) = :LogPrior +function DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) + return :LogPriorWithoutJacobian +end -DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} = LogPriorWithoutJacobianAccumulator(zero(T)) +function DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} + return LogPriorWithoutJacobianAccumulator(zero(T)) +end -function DynamicPPL.combine(acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) +function DynamicPPL.combine( + acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator +) return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp) end -function Base.:+(acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) +function Base.:+( + acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator +) return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp) end -Base.zero(acc::LogPriorWithoutJacobianAccumulator) = LogPriorWithoutJacobianAccumulator(zero(acc.logp)) +function Base.zero(acc::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(zero(acc.logp)) +end -function DynamicPPL.accumulate_assume!!(acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right) +function DynamicPPL.accumulate_assume!!( + acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right +) return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val)) end -DynamicPPL.accumulate_observe!!(acc::LogPriorWithoutJacobianAccumulator, right, left, vn) = acc +function DynamicPPL.accumulate_observe!!( + acc::LogPriorWithoutJacobianAccumulator, right, left, vn +) + return acc +end -function Base.convert(::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator) where {T} +function Base.convert( + ::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator +) where {T} return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) end -function DynamicPPL.convert_eltype(::Type{T}, acc::LogPriorWithoutJacobianAccumulator) where {T} +function DynamicPPL.convert_eltype( + ::Type{T}, acc::LogPriorWithoutJacobianAccumulator +) where {T} return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) end +function getlogprior_without_jacobian(vi::DynamicPPL.AbstractVarInfo) + acc = DynamicPPL.getacc(vi, Val(:LogPriorWithoutJacobian)) + return acc.logp +end + +function getlogjoint_without_jacobian(vi::DynamicPPL.AbstractVarInfo) + return getlogprior_without_jacobian(vi) + DynamicPPL.getloglikelihood(vi) +end + +# This is called when constructing a LogDensityFunction, and ensures the VarInfo has the +# right accumulators. +function DynamicPPL.ldf_default_varinfo( + model::DynamicPPL.Model, ::typeof(getlogprior_without_jacobian) +) + vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.setaccs!!(vi, (LogPriorWithoutJacobianAccumulator(),)) + return vi +end + +function DynamicPPL.ldf_default_varinfo( + model::DynamicPPL.Model, ::typeof(getlogjoint_without_jacobian) +) + vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.setaccs!!( + vi, (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) + ) + return vi +end + """ OptimLogDensity{ M<:DynamicPPL.Model, + F<:Function, V<:DynamicPPL.AbstractVarInfo, - AD<:ADTypes.AbstractADType + C<:DynamicPPL.AbstractContext, + AD<:ADTypes.AbstractADType, } A struct that wraps a single LogDensityFunction. 
Can be invoked either using @@ -147,14 +190,23 @@ optim_ld(z) # returns -logp """ struct OptimLogDensity{ M<:DynamicPPL.Model, + F<:Function, V<:DynamicPPL.AbstractVarInfo, + C<:DynamicPPL.AbstractContext, AD<:ADTypes.AbstractADType, } - ldf::DynamicPPL.LogDensityFunction{M,V,AD} + ldf::DynamicPPL.LogDensityFunction{M,F,V,C,AD} end -function OptimLogDensity(model::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model); adtype=AutoForwardDiff()) - return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi; adtype=adtype)) +function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function, + vi::DynamicPPL.AbstractVarInfo=DynamicPPL.ldf_default_varinfo(model, getlogdensity); + adtype=AutoForwardDiff(), +) + return OptimLogDensity( + DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + ) end """ @@ -324,7 +376,9 @@ function StatsBase.informationmatrix( linked = DynamicPPL.istrans(old_ldf.varinfo) if linked new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) - new_f = OptimLogDensity(old_ldf.model, new_vi; adtype=old_ldf.adtype) + new_f = OptimLogDensity( + old_ldf.model, old_ldf.getlogdensity, new_vi; adtype=old_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -337,7 +391,9 @@ function StatsBase.informationmatrix( if linked invlinked_ldf = m.f.ldf new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model) - new_f = OptimLogDensity(invlinked_ldf.model, new_vi; adtype=invlinked_ldf.adtype) + new_f = OptimLogDensity( + invlinked_ldf.model, old_ldf.getlogdensity, new_vi; adtype=invlinked_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -557,20 +613,16 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. - accs = if estimator isa MAP - (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) - else - (DynamicPPL.LogLikelihoodAccumulator(),) - end + getlogdensity = + estimator isa MAP ? getlogjoint_without_jacobian : DynamicPPL.getloglikelihood # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the # varinfo are completely ignored. The parameters only matter if you are calling evaluate!! # directly on the fields of the LogDensityFunction - vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.ldf_default_varinfo(model, getlogdensity) vi = DynamicPPL.unflatten(vi, initial_params) - vi = DynamicPPL.setaccs!!(vi, accs) # Link the varinfo if needed. # TODO(mhauru) We currently couple together the questions of whether the user specified @@ -582,7 +634,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi) + log_density = OptimLogDensity(model, getlogdensity, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) 
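As a usage-level summary of the API after this second patch: the objective is now selected by passing a log-density extractor function, and `DynamicPPL.ldf_default_varinfo` sets up the matching accumulators. A small sketch follows; the `coinflip` model and the probed value are hypothetical, while the extractor functions are the ones added or referenced in the diff above.

```julia
using Turing, DynamicPPL

# Hypothetical model, for illustration only.
@model function coinflip(y)
    p ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

model = coinflip([1, 1, 0])

# The extractor function picks the objective; accumulators come from
# `DynamicPPL.ldf_default_varinfo(model, getlogdensity)`.
mle_obj = Turing.Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
map_obj = Turing.Optimisation.OptimLogDensity(
    model, Turing.Optimisation.getlogjoint_without_jacobian
)

# Each callable returns the negative of its chosen log density at a parameter vector.
mle_obj([0.6]), map_obj([0.6])
```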
diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 0b22f34534..2fed91074e 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -27,7 +27,7 @@ using Turing hasstats(result) = result.optim_result.stats !== nothing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 - @testset "OptimizationContext" begin + @testset "OptimLogDensity and contexts" begin @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -44,34 +44,44 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - @test Turing.Optimisation.OptimLogDensity(m1)(w) == - Turing.Optimisation.OptimLogDensity(m2)(w) + # Doesn't matter if we use getlogjoint or getlogjoint_without_jacobian since the + # VarInfo isn't linked. + ld1 = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + @test ld1(w) == ld2(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - @test Turing.Optimisation.OptimLogDensity(m1)(w) == - Turing.Optimisation.OptimLogDensity(m2)(w) + ld1 = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + @test ld1(w) == ld2(w) end - @testset "Default, Likelihood, Prior Contexts" begin + @testset "Joint, prior, and likelihood" begin m1 = model1(x) - vi = DynamicPPL.VarInfo(m1) - vi_joint = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(), LogLikelihoodAccumulator())) - vi_prior = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(),)) - vi_likelihood = DynamicPPL.setaccs!!(deepcopy(vi), (LogLikelihoodAccumulator(),)) a = [0.3] - - @test Turing.Optimisation.OptimLogDensity(m1, vi_joint)(a) == - Turing.Optimisation.OptimLogDensity(m1, vi_prior)(a) + - Turing.Optimisation.OptimLogDensity(m1, vi_likelihood)(a) + ld_joint = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld_prior = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogprior_without_jacobian + ) + ld_likelihood = Turing.Optimisation.OptimLogDensity( + m1, DynamicPPL.getloglikelihood + ) + @test ld_joint(a) == ld_prior(a) + ld_likelihood(a) # test that the prior accumulator is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -629,8 +639,7 @@ using Turing return nothing end m = saddle_model() - vi = DynamicPPL.setaccs!!(DynamicPPL.VarInfo(m), (LogLikelihoodAccumulator(),)) - optim_ld = Turing.Optimisation.OptimLogDensity(m, vi) + optim_ld = Turing.Optimisation.OptimLogDensity(m, DynamicPPL.getloglikelihood) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 36bd7a9f68..f1b5b3d145 100644 --- a/test/test_utils/ad_utils.jl +++ 
b/test/test_utils/ad_utils.jl @@ -147,9 +147,7 @@ end # child context. function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) return value, vi end @@ -165,7 +163,9 @@ function DynamicPPL.tilde_assume( end function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vn, vi) - left, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vn, vi) + left, vi = DynamicPPL.tilde_observe!!( + DynamicPPL.childcontext(context), right, left, vn, vi + ) check_adtype(context, vi) return left, vi end
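For context on the accumulator interface that `LogPriorWithoutJacobianAccumulator` plugs into, here is a rough sketch of a custom accumulator that simply counts observe statements. The type and its name are made up for illustration, and the method set is assumed to mirror what this patch defines for `LogPriorWithoutJacobianAccumulator` rather than taken from DynamicPPL's documentation.

```julia
using DynamicPPL

# Hypothetical accumulator that counts tilde-observe statements during model evaluation.
struct ObservationCountAccumulator <: DynamicPPL.AbstractAccumulator
    n::Int
end
ObservationCountAccumulator() = ObservationCountAccumulator(0)

DynamicPPL.accumulator_name(::Type{<:ObservationCountAccumulator}) = :ObservationCount
DynamicPPL.split(::ObservationCountAccumulator) = ObservationCountAccumulator(0)
function DynamicPPL.combine(a::ObservationCountAccumulator, b::ObservationCountAccumulator)
    return ObservationCountAccumulator(a.n + b.n)
end
function DynamicPPL.accumulate_assume!!(
    acc::ObservationCountAccumulator, val, logjac, vn, right
)
    # Assume statements don't affect the count.
    return acc
end
function DynamicPPL.accumulate_observe!!(acc::ObservationCountAccumulator, right, left, vn)
    return ObservationCountAccumulator(acc.n + 1)
end
```

A `VarInfo` carrying it could then be set up with `DynamicPPL.setaccs!!(DynamicPPL.VarInfo(model), (ObservationCountAccumulator(),))`, the same way the optimisation code above selects its accumulators.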