Skip to content

sample with LogDensityFunction: part 2 - ess.jl + mh.jl #2590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: py/ldf-hmc
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

# Because this is a pain to implement all at once, we do it for one sampler at a time.
# This type tells us which samplers have been 'updated' to the new interface.

const LDFCompatibleSampler = Union{Hamiltonian}
const LDFCompatibleSampler = Union{Hamiltonian,ESS,MH}

"""
sample(
Expand Down Expand Up @@ -80,7 +79,7 @@
ctx = if ldf.context isa SamplingContext
ldf.context
else
SamplingContext(rng, spl)
SamplingContext(rng, spl, ldf.context)

Check warning on line 82 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L82

Added line #L82 was not covered by tests
end
Comment on lines 79 to 83
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise the existing context won't be obeyed

# Note that, in particular, sampling can mutate the variables in the LDF's
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
Expand Down Expand Up @@ -164,7 +163,7 @@
ctx = if ldf.context isa SamplingContext
ldf.context
else
SamplingContext(rng, spl)
SamplingContext(rng, spl, ldf.context)

Check warning on line 166 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L166

Added line #L166 was not covered by tests
end
# Note that, in particular, sampling can mutate the variables in the LDF's
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
Expand Down Expand Up @@ -282,7 +281,7 @@
initial_params = get(kwargs, :initial_params, nothing)
link = requires_unconstrained_space(spl)
vi = initialise_varinfo(rng, model, spl, initial_params, link)
ctx = SamplingContext(rng, spl)
ctx = SamplingContext(rng, spl, model.context)

Check warning on line 284 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L284

Added line #L284 was not covered by tests
ldf = LogDensityFunction(model, vi, ctx; adtype=get_adtype(spl))
# No need to run check_model again
return AbstractMCMC.sample(rng, ldf, spl, N; kwargs..., check_model=false)
Expand Down Expand Up @@ -331,7 +330,7 @@
initial_params = get(kwargs, :initial_params, nothing)
link = requires_unconstrained_space(spl)
vi = initialise_varinfo(rng, model, spl, initial_params, link)
ctx = SamplingContext(rng, spl)
ctx = SamplingContext(rng, spl, model.context)

Check warning on line 333 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L333

Added line #L333 was not covered by tests
ldf = LogDensityFunction(model, vi, ctx; adtype=get_adtype(spl))
# No need to run check_model again
return AbstractMCMC.sample(
Expand Down
73 changes: 33 additions & 40 deletions src/mcmc/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@
│ 1 │ m │ 0.824853 │
```
"""
struct ESS <: InferenceAlgorithm end
struct ESS <: AbstractSampler end

DynamicPPL.initialsampler(::ESS) = DynamicPPL.SampleFromPrior()
update_sample_kwargs(::ESS, ::Integer, kwargs) = kwargs
get_adtype(::ESS) = nothing
requires_unconstrained_space(::ESS) = false

Check warning on line 28 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L25-L28

Added lines #L25 - L28 were not covered by tests

# always accept in the first step
function DynamicPPL.initialstep(
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
)
function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS; kwargs...)
vi = ldf.varinfo

Check warning on line 32 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
for vn in keys(vi)
dist = getdist(vi, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("ESS only supports Gaussian prior distributions")
end
return Transition(model, vi), vi
return Transition(ldf.model, vi), vi

Check warning on line 38 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L38

Added line #L38 was not covered by tests
end

function AbstractMCMC.step(
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS, vi::AbstractVarInfo; kwargs...
)
# obtain previous sample
f = vi[:]
Expand All @@ -45,14 +49,13 @@
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)

# compute next state
# Note: `f_loglikelihood` effectively calculates the log-likelihood (not
# log-joint, despite the use of `LDP.logdensity`) because `tilde_assume` is
# overloaded on `SamplingContext(rng, ESS(), ...)` below.
f_loglikelihood = Base.Fix1(LogDensityProblems.logdensity, ldf)

Check warning on line 55 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L55

Added line #L55 was not covered by tests
sample, state = AbstractMCMC.step(
rng,
EllipticalSliceSampling.ESSModel(
ESSPrior(model, spl, vi),
DynamicPPL.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
),
EllipticalSliceSampling.ESSModel(ESSPrior(ldf.model, spl, vi), f_loglikelihood),
EllipticalSliceSampling.ESS(),
oldstate,
)
Expand All @@ -61,67 +64,57 @@
vi = DynamicPPL.unflatten(vi, sample)
vi = setlogp!!(vi, state.loglikelihood)

return Transition(model, vi), vi
return Transition(ldf.model, vi), vi

Check warning on line 67 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L67

Added line #L67 was not covered by tests
end

# Prior distribution of considered random variable
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
model::M
sampler::S
sampler::ESS
varinfo::V
μ::T

function ESSPrior{M,S,V}(
model::M, sampler::S, varinfo::V
) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
function ESSPrior{M,V}(

Check warning on line 77 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L77

Added line #L77 was not covered by tests
model::M, sampler::ESS, varinfo::V
) where {M<:Model,V<:AbstractVarInfo}
vns = keys(varinfo)
μ = mapreduce(vcat, vns) do vn
dist = getdist(varinfo, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("[ESS] only supports Gaussian prior distributions")
DynamicPPL.tovec(mean(dist))
end
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
return new{M,V,typeof(μ)}(model, sampler, varinfo, μ)

Check warning on line 87 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L87

Added line #L87 was not covered by tests
end
end

function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
function ESSPrior(model::Model, sampler::ESS, varinfo::AbstractVarInfo)
return ESSPrior{typeof(model),typeof(varinfo)}(model, sampler, varinfo)

Check warning on line 92 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

# Ensure that the prior is a Gaussian distribution (checked in the constructor)
EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true

# Only define out-of-place sampling
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
vns = keys(varinfo)
for vn in vns
set_flag!(varinfo, vn, "del")
# TODO(penelopeysm): This is ugly -- need to set 'del' flag because
# otherwise DynamicPPL.SampleWithPrior will just use the existing
# parameters in the varinfo. In general SampleWithPrior etc. need to be
# reworked.
for vn in keys(p.varinfo)
set_flag!(p.varinfo, vn, "del")

Check warning on line 105 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L104-L105

Added lines #L104 - L105 were not covered by tests
end
p.model(rng, varinfo, sampler)
return varinfo[:]
_, vi = DynamicPPL.evaluate!!(p.model, p.varinfo, SamplingContext(rng, p.sampler))
return vi[:]

Check warning on line 108 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end

# Mean of prior distribution
Distributions.mean(p::ESSPrior) = p.μ

# Evaluate log-likelihood of proposals
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}

(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)

function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
rng::Random.AbstractRNG, ::DefaultContext, ::ESS, right, vn, vi
)
return DynamicPPL.tilde_assume(
rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
)
end

function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
end
Loading
Loading