Skip to content

[WIP] InitContext #967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: breaking
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -440,33 +440,24 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
SamplingContext
DefaultContext
PrefixContext
ConditionContext
InitContext
```

### Samplers
### VarInfo initialisation

In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
TODO

```@docs
SampleFromPrior
SampleFromUniform
```
### Samplers

Additionally, a generic sampler for inference is implemented.
In DynamicPPL a generic sampler for inference is implemented.

```@docs
Sampler
Expand All @@ -477,7 +468,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
Expand Down
2 changes: 0 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ else
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =
Expand Down
11 changes: 3 additions & 8 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,12 @@
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
varinfo = DynamicPPL.typed_varinfo(model)

Check warning on line 25 in ext/DynamicPPLJETExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLJETExt.jl#L25

Added line #L25 was not covered by tests

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
model, varinfo; only_ddpl
)

if !issuccess
Expand All @@ -46,8 +41,8 @@
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end

end
39 changes: 29 additions & 10 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -114,9 +123,17 @@

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)

Check warning on line 127 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L127

Added line #L127 was not covered by tests
# Resample any variables that are not present in `values_dict`
_, varinfo = last(

Check warning on line 129 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L129

Added line #L129 was not covered by tests
DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
Expand Down Expand Up @@ -248,13 +265,15 @@
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)

Check warning on line 269 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L269

Added line #L269 was not covered by tests
# Resample any variables that are not present in `values_dict`, and
# return the model's retval (`first`).
first(

Check warning on line 272 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L272

Added line #L272 was not covered by tests
DynamicPPL.init!!(
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
),
)
end
end

Expand Down
10 changes: 6 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,14 @@ export AbstractVarInfo,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
SampleFromUniform,
# Initialisation strategies
PriorInit,
UniformInit,
ParamsInit,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
ConditionContext,
Expand Down Expand Up @@ -170,11 +171,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/init.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("submodel.jl")
include("varnamedvector.jl")
include("accumulators.jl")
Expand Down
91 changes: 0 additions & 91 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,11 @@
# assume
"""
tilde_assume(context::SamplingContext, right, vn, vi)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value with a context associated
with a sampler.

Falls back to
```julia
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
```
"""
function tilde_assume(context::SamplingContext, right, vn, vi)
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
end

function tilde_assume(context::AbstractContext, args...)
return tilde_assume(childcontext(context), args...)
end
function tilde_assume(::DefaultContext, right, vn, vi)
return assume(right, vn, vi)
end

function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
return tilde_assume(rng, childcontext(context), args...)
end
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
# same as above but no rng
return assume(Random.default_rng(), sampler, right, vn, vi)
end

function tilde_assume(context::PrefixContext, right, vn, vi)
# Note that we can't use something like this here:
# new_vn = prefix(context, vn)
Expand All @@ -46,12 +19,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(new_context, right, new_vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
)
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
end

"""
tilde_assume!!(context, right, vn, vi)
Expand All @@ -71,17 +38,6 @@ function tilde_assume!!(context, right, vn, vi)
end

# observe
"""
tilde_observe!!(context::SamplingContext, right, left, vi)

Handle observed constants with a `context` associated with a sampler.

Falls back to `tilde_observe!!(context.context, right, left, vi)`.
"""
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
return tilde_observe!!(context.context, right, left, vn, vi)
end

function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
return tilde_observe!!(childcontext(context), right, left, vn, vi)
end
Expand Down Expand Up @@ -115,10 +71,6 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
return left, vi
end

function assume(::Random.AbstractRNG, spl::Sampler, dist)
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
y = getindex_internal(vi, vn)
Expand All @@ -127,46 +79,3 @@ function assume(dist::Distribution, vn::VarName, vi)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
end

# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
rng::Random.AbstractRNG,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
# if that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
# TODO(mhauru) This should probably be call a function called setindex_internal!
vi = BangBang.setindex!!(vi, f(r), vn)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, vn, dist)
vi = push!!(vi, vn, f(r), dist)
# By default `push!!` sets the transformed flag to `false`.
vi = settrans!!(vi, true, vn)
else
vi = push!!(vi, vn, r, dist)
end
end

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
return r, vi
end
Loading
Loading