Skip to content

Remove context from model evaluation (use model.context instead) #952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,42 @@ This release overhauls how VarInfo objects track variables such as the log joint
- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.

### Evaluation contexts

Historically, evaluating a DynamicPPL model has required three arguments: a model, some kind of VarInfo, and a context.
It's less known, though, that since DynamicPPL 0.14.0 the _model_ itself actually contains a context as well.
This version therefore excises the context argument, and instead uses `model.context` as the evaluation context.

The upshot of this is that many functions that previously took a context argument now no longer do.
There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value).

`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`.
If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost.
If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely.

To aid with this process, `contextualize` is now exported from DynamicPPL.

The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`.
Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object.
Thus, this release also introduces the **unexported** function `evaluate_and_sample!!`.
Essentially, `evaluate_and_sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`.
**Do note that this is an internal method**, and its name or semantics are liable to change in the future without warning.

There are many methods that no longer take a context argument, and listing them all would be too much.
However, here are the more user-facing ones:

- `LogDensityFunction` no longer has a context field (or type parameter)
- `DynamicPPL.TestUtils.AD.run_ad` no longer uses a context (and the returned `ADResult` object no longer has a context field)
- `VarInfo(rng, model, sampler)` and other VarInfo constructors / functions that made VarInfos (e.g. `typed_varinfo`) from a model
- `(::Model)(args...)`: specifically, this now only takes `rng` and `varinfo` arguments (with both being optional)
- If you are using the `__context__` special variable inside a model, you will now have to use `__model__.context` instead

And a couple of more internal changes:

- `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments
- `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`)
- The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument

## 0.36.12

Removed several unexported functions.
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
end

adbackend = to_backend(adbackend)
context = DynamicPPL.DefaultContext()

if islinked
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
18 changes: 16 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ getargnames
getmissings
```

The context of a model can be set using [`contextualize`](@ref):

```@docs
contextualize
```

## Evaluation

With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
Expand Down Expand Up @@ -438,13 +444,21 @@ DynamicPPL.varname_and_value_leaves

### Evaluation Contexts

Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref).
Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref).

```@docs
AbstractPPL.evaluate!!
```

The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function.
This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
Expand Down
1 change: 0 additions & 1 deletion ext/DynamicPPLForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ function DynamicPPL.tweak_adtype(
ad::ADTypes.AutoForwardDiff{chunk_size},
::DynamicPPL.Model,
vi::DynamicPPL.AbstractVarInfo,
::DynamicPPL.AbstractContext,
) where {chunk_size}
params = vi[:]

Expand Down
22 changes: 11 additions & 11 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL
using JET: JET

function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo;
only_ddpl::Bool=true,
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
result = if only_ddpl
Expand All @@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo(
end

function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)
varinfo = DynamicPPL.typed_varinfo(sampling_model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
model, context, varinfo; only_ddpl
sampling_model, varinfo; only_ddpl
)

if !issuccess
Expand All @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
DynamicPPL.untyped_varinfo(sampling_model)
end
end

Expand Down
2 changes: 1 addition & 1 deletion ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function DynamicPPL.predict(
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export AbstractVarInfo,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
Expand Down
30 changes: 15 additions & 15 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
const INTERNALNAMES = (:__model__, :__varinfo__)

"""
need_concretize(expr)
Expand Down Expand Up @@ -63,9 +63,9 @@
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
return quote
if $(DynamicPPL.contextual_isassumption)(
__context__, $(DynamicPPL.prefix)(__context__, $vn)
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# Considered an assumption by `__context__` which means either:
# Considered an assumption by `__model__.context` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
Expand Down Expand Up @@ -116,7 +116,7 @@
isfixed(expr, vn) = false
function isfixed(::Union{Symbol,Expr}, vn)
return :($(DynamicPPL.contextual_isfixed)(
__context__, $(DynamicPPL.prefix)(__context__, $vn)
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
))
end

Expand Down Expand Up @@ -417,7 +417,7 @@
return quote
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))

Check warning on line 420 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L420

Added line #L420 was not covered by tests
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
)
Expand All @@ -431,7 +431,11 @@
@gensym value
return quote
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__
__model__.context,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
nothing,
__varinfo__,
)
$value
end
Expand All @@ -456,20 +460,20 @@
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(
__context__, $(DynamicPPL.prefix)(__context__, $vn)
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(
__context__, $(DynamicPPL.prefix)(__context__, $vn)
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__context__,
__model__.context,
$(DynamicPPL.check_tilde_rhs)($dist),
$(maybe_view(left)),
$vn,
Expand All @@ -494,7 +498,7 @@

return quote
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
__context__,
__model__.context,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
)
Expand Down Expand Up @@ -652,11 +656,7 @@

# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
],
[:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))],
args,
)

Expand Down
Loading
Loading