Skip to content

Accumulators stage 2 #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: breaking
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend
)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
113 changes: 76 additions & 37 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"""
LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
context::AbstractContext=DefaultContext();
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
)
Expand All @@ -28,10 +29,10 @@
- and if `adtype` is provided, calculate the gradient of the log density at
that point.

At its most basic level, a LogDensityFunction wraps the model together with its
the type of varinfo to be used, as well as the evaluation context. These must
be known in order to calculate the log density (using
[`DynamicPPL.evaluate!!`](@ref)).
At its most basic level, a LogDensityFunction wraps the model together with
the type of varinfo to be used, as well as the evaluation context and a function
to extract the log density from the VarInfo. These must be known in order to
calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).

If the `adtype` keyword argument is provided, then this struct will also store
the adtype along with other information for efficient calculation of the
Expand Down Expand Up @@ -73,13 +74,13 @@
1

julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
f = LogDensityFunction(model, SimpleVarInfo(model));
f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # LogDensityFunction respects the accumulators in VarInfo:
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
julia> # One can also specify evaluating e.g. the log prior only:
f_prior = LogDensityFunction(model, getlogprior);

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
Expand All @@ -94,11 +95,17 @@
```
"""
struct LogDensityFunction{
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
M<:Model,
F<:Function,
V<:AbstractVarInfo,
C<:AbstractContext,
AD<:Union{Nothing,ADTypes.AbstractADType},
}
"model used for evaluation"
model::M
"varinfo used for evaluation"
"function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
getlogdensity::F
"varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
varinfo::V
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
context::C
Expand All @@ -109,7 +116,8 @@

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
Expand All @@ -125,21 +133,28 @@
x = map(identity, varinfo[:])
if use_closure(adtype)
prep = DI.prepare_gradient(
x -> logdensity_at(x, model, varinfo, context), adtype, x
x -> logdensity_at(x, model, getlogdensity, varinfo, context), adtype, x
)
else
prep = DI.prepare_gradient(
logdensity_at,
adtype,
x,
DI.Constant(model),
DI.Constant(getlogdensity),
DI.Constant(varinfo),
DI.Constant(context),
)
end
end
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
model, varinfo, context, adtype, prep
return new{
typeof(model),
typeof(getlogdensity),
typeof(varinfo),
typeof(context),
typeof(adtype),
}(
model, getlogdensity, varinfo, context, adtype, prep
)
end
end
Expand All @@ -164,64 +179,87 @@
end
end

"""
ldf_default_varinfo(model::Model, getlogdensity::Function)

Create the default AbstractVarInfo that should be used for evaluating the log density.

Only the accumulators necesessary for `getlogdensity` will be used.
"""
function ldf_default_varinfo(::Model, getlogdensity::Function)
msg = """

Check warning on line 190 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L189-L190

Added lines #L189 - L190 were not covered by tests
LogDensityFunction does not know what sort of VarInfo should be used when \
`getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
"""
return error(msg)

Check warning on line 194 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L194

Added line #L194 was not covered by tests
end

ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)

function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
return setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
end

function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))
return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),))
end

"""
logdensity_at(
x::AbstractVector,
model::Model,
getlogdensity::Function,
varinfo::AbstractVarInfo,
context::AbstractContext
)

Evaluate the log density of the given `model` at the given parameter values `x`,
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
only for its structure, in the sense that the parameters from the vector `x` are inserted
into it, and its own parameters are discarded. It does, however, determine whether the log
prior, likelihood, or joint is returned, based on which accumulators are set in it.
into it, and its own parameters are discarded. `getlogdensity` is the function that extracts
the log density from the evaluated varinfo.
"""
function logdensity_at(
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
x::AbstractVector,
model::Model,
getlogdensity::Function,
varinfo::AbstractVarInfo,
context::AbstractContext,
)
varinfo_new = unflatten(varinfo, x)
varinfo_eval = last(evaluate!!(model, varinfo_new, context))
has_prior = hasacc(varinfo_eval, Val(:LogPrior))
has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
if has_prior && has_likelihood
return getlogjoint(varinfo_eval)
elseif has_prior
return getlogprior(varinfo_eval)
elseif has_likelihood
return getloglikelihood(varinfo_eval)
else
error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
end
return getlogdensity(varinfo_eval)
end

### LogDensityProblems interface

function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,Nothing}}
) where {M,V,C}
::Type{<:LogDensityFunction{M,F,V,C,Nothing}}
) where {M,F,V,C}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,AD}}
) where {M,V,C,AD<:ADTypes.AbstractADType}
::Type{<:LogDensityFunction{M,F,V,C,AD}}
) where {M,F,V,C,AD<:ADTypes.AbstractADType}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
return logdensity_at(x, f.model, f.varinfo, f.context)
return logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context)
end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
) where {M,V,C,AD<:ADTypes.AbstractADType}
f::LogDensityFunction{M,F,V,C,AD}, x::AbstractVector
) where {M,F,V,C,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
DI.value_and_gradient(
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context),

Check warning on line 259 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L259

Added line #L259 was not covered by tests
f.prep,
f.adtype,
x,
)
else
DI.value_and_gradient(
Expand All @@ -230,6 +268,7 @@
f.adtype,
x,
DI.Constant(f.model),
DI.Constant(f.getlogdensity),
DI.Constant(f.varinfo),
DI.Constant(f.context),
)
Expand Down Expand Up @@ -304,7 +343,7 @@
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
return LogDensityFunction(model, f.getlogdensity, f.varinfo, f.context; adtype=f.adtype)
end

"""
Expand Down
8 changes: 6 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,9 @@
x = vi.values
y, logjac = with_logabsdet_jacobian(b, x)
vi_new = Accessors.@set(vi.values = y)
vi_new = acclogprior!!(vi_new, -logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, -logjac)

Check warning on line 610 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L609-L610

Added lines #L609 - L610 were not covered by tests
end
return settrans!!(vi_new, t)
end

Expand All @@ -619,7 +621,9 @@
y = vi.values
x, logjac = with_logabsdet_jacobian(b, y)
vi_new = Accessors.@set(vi.values = x)
vi_new = acclogprior!!(vi_new, logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, logjac)

Check warning on line 625 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L624-L625

Added lines #L624 - L625 were not covered by tests
end
return settrans!!(vi_new, NoTransformation())
end

Expand Down
8 changes: 5 additions & 3 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
Expand Down Expand Up @@ -184,7 +184,7 @@

verbose && @info "Running AD on $(model.f) with $(adtype)\n"
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
ldf = LogDensityFunction(model, getlogjoint, varinfo; adtype=adtype)

Check warning on line 187 in src/test_utils/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/test_utils/ad.jl#L187

Added line #L187 was not covered by tests

value, grad = logdensity_and_gradient(ldf, params)
grad = collect(grad)
Expand All @@ -193,7 +193,9 @@
if test
# Calculate ground truth to compare against
value_true, grad_true = if expected_value_and_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
ldf_reference = LogDensityFunction(

Check warning on line 196 in src/test_utils/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/test_utils/ad.jl#L196

Added line #L196 was not covered by tests
model, getlogjoint, varinfo; adtype=reference_adtype
)
logdensity_and_gradient(ldf_reference, params)
else
expected_value_and_grad
Expand Down
20 changes: 15 additions & 5 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,9 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
setrange!(md, vn, start:(start + length(yvec) - 1))
# Set the new value.
setval!(md, yvec, vn)
vi = acclogprior!!(vi, -logjac)
if hasacc(vi, Val(:LogPrior))
vi = acclogprior!!(vi, -logjac)
end
return vi
end

Expand Down Expand Up @@ -1278,7 +1280,9 @@ function _link(model::Model, varinfo::VarInfo, vns)
varinfo = deepcopy(varinfo)
md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

Expand All @@ -1292,7 +1296,9 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
varinfo = deepcopy(varinfo)
md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

Expand Down Expand Up @@ -1441,7 +1447,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns)
varinfo = deepcopy(varinfo)
md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

Expand All @@ -1455,7 +1463,9 @@ function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
varinfo = deepcopy(varinfo)
md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns)
new_varinfo = VarInfo(md, varinfo.accs)
new_varinfo = acclogprior!!(new_varinfo, -logjac)
if hasacc(new_varinfo, Val(:LogPrior))
new_varinfo = acclogprior!!(new_varinfo, -logjac)
end
return new_varinfo
end

Expand Down
2 changes: 1 addition & 1 deletion src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ This is a very restricted version `subumes(u::VarName, v::VarName)` only really
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.

## Note
- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)`
- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)`
for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`,
and similarly to `v`. But this is slow.
"""
Expand Down
12 changes: 9 additions & 3 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ using DynamicPPL: LogDensityFunction

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
linked_varinfo = DynamicPPL.link(varinfo, m)
f = LogDensityFunction(m, linked_varinfo)
f = LogDensityFunction(m, getlogjoint, linked_varinfo)
x = DynamicPPL.getparams(f)
# Calculate reference logp + gradient of logp using ForwardDiff
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
ref_ldf = LogDensityFunction(
m, getlogjoint, linked_varinfo; adtype=ref_adtype
)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)

@testset "$adtype" for adtype in test_adtypes
Expand Down Expand Up @@ -106,7 +108,11 @@ using DynamicPPL: LogDensityFunction
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = LogDensityFunction(
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
model,
getlogjoint,
vi,
SamplingContext(spl);
adtype=AutoReverseDiff(; compile=true),
)
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
end
Expand Down
Loading
Loading