Commit f11100e

Move DPPL submodel/condition docs to here
1 parent 6baaac3 commit f11100e

4 files changed: +361 -1 lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 julia_version = "1.11.5"
 manifest_format = "2.0"
-project_hash = "abb3c770eb08cdd80c9627a8a7c327584291f4c8"
+project_hash = "93e3f90921e5771e56e7b8d21131b1107faa4765"
 
 [[deps.ADTypes]]
 git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32"

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
+AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"

_quarto.yml

Lines changed: 6 additions & 0 deletions
@@ -98,6 +98,11 @@ website:
       - developers/compiler/minituring-contexts/index.qmd
       - developers/compiler/design-overview/index.qmd
 
+  - section: "DynamicPPL Contexts"
+    collapse-level: 1
+    contents:
+      - developers/contexts/submodel-condition/index.qmd
+
   - section: "Variable Transformations"
     collapse-level: 1
     contents:
@@ -204,3 +209,4 @@ using-turing-implementing-samplers: developers/inference/implementing-samplers
 dev-transforms-distributions: developers/transforms/distributions
 dev-transforms-bijectors: developers/transforms/bijectors
 dev-transforms-dynamicppl: developers/transforms/dynamicppl
+dev-contexts-submodel-condition: developers/contexts/submodel-condition

developers/contexts/submodel-condition/index.qmd

Lines changed: 353 additions & 0 deletions
@@ -0,0 +1,353 @@
---
title: "Conditioning and fixing in submodels"
engine: julia
---

## PrefixContext

Submodels in DynamicPPL come with the notion of _prefixing_ variables: under the hood, this is implemented by adding a `PrefixContext` to the context stack.

`PrefixContext` is a context that, as the name suggests, prefixes all variables inside a model with a given symbol.
Thus, for example:

```{julia}
using DynamicPPL, Distributions

@model function f()
    x ~ Normal()
    return y ~ Normal()
end

@model function g()
    return a ~ to_submodel(f())
end
```

inside the submodel `f`, the variables `x` and `y` become `a.x` and `a.y` respectively.
This is easiest to observe by running the model:

```{julia}
vi = VarInfo(g())
keys(vi)
```

::: {.callout-note}
In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde.
We will return to the 'manual prefixing' case later.
:::

The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is `tilde_assume`.
The method responsible for it is `tilde_assume(::PrefixContext, right, vn, vi)`: this attaches the prefix in the context to the `VarName` argument, before recursively calling `tilde_assume` with the new prefixed `VarName`.
This means that even though a statement `x ~ dist` still enters the tilde pipeline at the top level as `x`, if the model evaluation context contains a `PrefixContext`, any function after `tilde_assume(::PrefixContext, ...)` will see `a.x` instead.
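
As a rough illustration of where the prefix gets attached, the following toy sketch mimics the recursion with plain strings. The types `SketchLeaf` and `SketchPrefix` and the function `sketch_assume` are hypothetical stand-ins invented for this example; they are not DynamicPPL's types or methods, and only a single prefix layer is considered.

```{julia}
# Toy stand-ins (hypothetical, not DynamicPPL types): a context is either a
# leaf or a prefix wrapper around a child context.
abstract type SketchContext end
struct SketchLeaf <: SketchContext end
struct SketchPrefix <: SketchContext
    prefix::String
    child::SketchContext
end

# At a leaf, the name is what the rest of the pipeline finally "sees".
sketch_assume(::SketchLeaf, name::String) = name

# At a prefix layer, attach the prefix first, then recurse into the child.
function sketch_assume(ctx::SketchPrefix, name::String)
    return sketch_assume(ctx.child, ctx.prefix * "." * name)
end

# The statement enters as "x", but everything below the prefix layer sees "a.x".
sketch_assume(SketchPrefix("a", SketchLeaf()), "x")
```

Anything called below the prefix layer in this sketch only ever receives the prefixed name, which is the behaviour described above for functions that run after `tilde_assume(::PrefixContext, ...)`.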

## ConditionContext

`ConditionContext` is a context which stores values of variables that are to be conditioned on.
These values may be stored as a `Dict` which maps `VarName`s to values, or alternatively as a `NamedTuple`.
The latter only works correctly if all `VarName`s are 'basic', in that they have an identity optic (i.e., something like `a.x` or `a[1]` is forbidden).
Because of this limitation, we will only use `Dict` in this example.

::: {.callout-note}
If a `ConditionContext` with a `NamedTuple` encounters anything to do with a prefix, its internal `NamedTuple` is converted to a `Dict` anyway, so it is quite reasonable to ignore the `NamedTuple` case in this exposition.
:::

One can inspect the conditioning values with, for example:

```{julia}
@model function d()
    x ~ Normal()
    return y ~ Normal()
end

cond_model = d() | (@varname(x) => 1.0)
cond_ctx = cond_model.context
```

There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is.

```{julia}
DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x))
```

```{julia}
DynamicPPL.getconditioned_nested(cond_ctx, @varname(x))
```

These functions are in turn used by the function `DynamicPPL.contextual_isassumption`, which is largely the same as `hasconditioned_nested`, but also checks whether the value is `missing` (in which case it isn't really conditioned).

```{julia}
DynamicPPL.contextual_isassumption(cond_ctx, @varname(x))
```

::: {.callout-note}
Notice that (neglecting `missing` values) the return value of `contextual_isassumption` is the _opposite_ of `hasconditioned_nested`, i.e. for a variable that _is_ conditioned on, `contextual_isassumption` returns `false`.
:::
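
The relationship between the two checks, including the `missing` case, can be sketched with a plain `Dict` keyed by strings. The helpers `toy_hasconditioned` and `toy_isassumption` are hypothetical and only mirror the logic described above; they are not the DynamicPPL functions.

```{julia}
# Hypothetical helpers operating on a plain Dict of conditioned values.
toy_hasconditioned(values::Dict{String,Any}, name::String) = haskey(values, name)

function toy_isassumption(values::Dict{String,Any}, name::String)
    # Not conditioned at all: the variable is a random variable (an assumption).
    toy_hasconditioned(values, name) || return true
    # Conditioned on `missing`: treated as if it were not conditioned.
    return values[name] === missing
end

toy_values = Dict{String,Any}("x" => 1.0, "y" => missing)

(
    toy_isassumption(toy_values, "x"),  # conditioned on a real value
    toy_isassumption(toy_values, "y"),  # conditioned on `missing`
    toy_isassumption(toy_values, "z"),  # not conditioned at all
)
```

Only the first entry is `false`: a variable counts as an observation in this sketch only when it has a non-`missing` conditioned value.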

If a variable `x` is conditioned on, then the effect of this is to set the value of `x` to the given value (while still including its contribution to the log probability density).
Since `x` is no longer a random variable, if we were to evaluate the model, we would find only one key in the `VarInfo`:

```{julia}
keys(VarInfo(cond_model))
```

## Joint behaviour: desiderata at the model level

When paired together, these two contexts have the potential to cause substantial confusion: `PrefixContext` modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the `ConditionContext`.

We begin by mentioning some high-level desiderata for their joint behaviour.
Take these models, for example:

```{julia}
# We define a helper function to unwrap a layer of SamplingContext, to
# avoid cluttering the print statements.
unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context
unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx

@model function inner()
    println("inner context: $(unwrap_sampling_context(__context__))")
    x ~ Normal()
    return y ~ Normal()
end

@model function outer()
    println("outer context: $(unwrap_sampling_context(__context__))")
    return a ~ to_submodel(inner())
end

# 'Outer conditioning'
with_outer_cond = outer() | (@varname(a.x) => 1.0)

# 'Inner conditioning'
inner_cond = inner() | (@varname(x) => 1.0)
@model function outer2()
    println("outer context: $(unwrap_sampling_context(__context__))")
    return a ~ to_submodel(inner_cond)
end
with_inner_cond = outer2()
```

We want that:

1. `keys(VarInfo(outer()))` should return `[a.x, a.y]`;
2. `keys(VarInfo(with_outer_cond))` should return `[a.y]`;
3. `keys(VarInfo(with_inner_cond))` should return `[a.y]`.

**In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.**

This is an incredibly salient point because it means that submodels can be treated as individual, opaque objects, and we can condition them without needing to know what they will be prefixed with, or the context in which they are being used.
For example, this means we can reuse `inner_cond` in another model with a different prefix, and it will _still_ have its inner `x` value be conditioned, despite the prefix differing.

::: {.callout-note}
In the current version of DynamicPPL, these criteria are all fulfilled.
However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside.
(See [this GitHub issue](https://github.com/TuringLang/DynamicPPL.jl/issues/857) for more information; this issue was the direct motivation for this documentation page.)
:::

## Desiderata at the context level

The above section describes how we expect conditioning and prefixing to behave from a user's perspective.
We now turn to the question of how we implement this in terms of DynamicPPL contexts.
We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour.

**Point (1)** does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the `tilde_assume` method shown above.

**Points (2) and (3)** are more tricky.
As the reader may surmise, the difference between them is the order in which the contexts are stacked.

For the _outer_ conditioning case (point (2)), the `ConditionContext` will contain a `VarName` that is already prefixed.
When we enter the inner submodel, this `ConditionContext` has to be passed down and somehow combined with the `PrefixContext` that is created when we enter the submodel.
We make the claim here that the best way to do this is to nest the `PrefixContext` _inside_ the `ConditionContext`.
This is indeed what happens, as can be demonstrated by running the model.

```{julia}
with_outer_cond()
```

For the _inner_ conditioning case (point (3)), the outer model is not run with any special context.
The inner model will itself contain a `ConditionContext`, which will contain a `VarName` that is not prefixed.
When we run the model, this `ConditionContext` should then be nested _inside_ a `PrefixContext` to form the final evaluation context.
Again, we can run the model to see this in action:

```{julia}
with_inner_cond()
```

Putting all of the information so far together, what it means is that if we have these two inner contexts (taken from above):

```{julia}
using DynamicPPL: PrefixContext, ConditionContext, DefaultContext

inner_ctx_with_outer_cond = ConditionContext(
    Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a))
)
inner_ctx_with_inner_cond = PrefixContext(
    @varname(a), ConditionContext(Dict(@varname(x) => 1.0))
)
```

then we want both of these to be `true` (and thankfully, they are!):

```{julia}
DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x))
```

```{julia}
DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x))
```

This allows us to finally specify our task as follows:

(1) Given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly.

(2) We need to make sure that the correct arguments are supplied. In order to do so:

- (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that `PrefixContext` is applied _inside_ the parent model's context, but _outside_ the submodel's own context.

- (2b) We also need to make sure that the `VarName` passed to these functions is prefixed correctly.

## How do we do it?

(1) `hasconditioned_nested` and `getconditioned_nested` accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all `PrefixContext`s, and apply those prefixes to any conditioned variables below them in the stack.
Once the `PrefixContext`s have been removed, one can then iterate through the context stack and check if any of the `ConditionContext`s contain the variable, or get the value itself.
For more details the reader is encouraged to read the source code.
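
To make the 'collapsing' idea concrete, here is a minimal sketch on a toy context stack that uses strings for variable names. All of the names (`ToyLeafCtx`, `ToyPrefixCtx`, `ToyCondCtx`, `toy_collapse`, `toy_add_prefix`, `toy_hascond`) are hypothetical and chosen for this illustration; the real DynamicPPL implementation works on `VarName`s and differs in detail.

```{julia}
abstract type ToyCtx end
struct ToyLeafCtx <: ToyCtx end
struct ToyPrefixCtx <: ToyCtx
    prefix::String
    child::ToyCtx
end
struct ToyCondCtx <: ToyCtx
    values::Dict{String,Float64}
    child::ToyCtx
end

# Remove every prefix layer, pushing its prefix onto all conditioned names
# that sit below it in the stack.
toy_collapse(ctx::ToyLeafCtx) = ctx
toy_collapse(ctx::ToyCondCtx) = ToyCondCtx(ctx.values, toy_collapse(ctx.child))
toy_collapse(ctx::ToyPrefixCtx) = toy_add_prefix(toy_collapse(ctx.child), ctx.prefix)

# Rename conditioned variables in an already-collapsed stack.
toy_add_prefix(ctx::ToyLeafCtx, ::String) = ctx
function toy_add_prefix(ctx::ToyCondCtx, prefix::String)
    renamed = Dict{String,Float64}(prefix * "." * k => v for (k, v) in ctx.values)
    return ToyCondCtx(renamed, toy_add_prefix(ctx.child, prefix))
end

# After collapsing, only leaf and conditioning layers remain to be searched.
toy_hascond(::ToyLeafCtx, ::String) = false
function toy_hascond(ctx::ToyCondCtx, name::String)
    return haskey(ctx.values, name) || toy_hascond(ctx.child, name)
end

# The two stacking orders discussed above: conditioning outside or inside the prefix.
toy_outer_style = ToyCondCtx(Dict("a.x" => 1.0), ToyPrefixCtx("a", ToyLeafCtx()))
toy_inner_style = ToyPrefixCtx("a", ToyCondCtx(Dict("x" => 1.0), ToyLeafCtx()))

(
    toy_hascond(toy_collapse(toy_outer_style), "a.x"),
    toy_hascond(toy_collapse(toy_inner_style), "a.x"),
)
```

Both stacking orders report the fully prefixed name `a.x` as conditioned once the stack has been collapsed, which matches the two `true` results shown earlier.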

(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`.
This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that _the model's context is nested inside the external context_.
Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an external context to give the behaviour seen above.
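
One plausible way to picture 'nesting the model's context inside the external context' is to replace the leaf of the external stack with the model's own stack. The sketch below does this for a toy stack of labelled layers; `NestLeaf`, `NestLayer`, `nest_inside` and `nest_labels` are hypothetical names, and this is not a description of how `make_evaluate_args_and_kwargs` is actually written.

```{julia}
abstract type NestCtx end
struct NestLeaf <: NestCtx end
struct NestLayer <: NestCtx
    label::String
    child::NestCtx
end

# Replace the leaf of `external` with `inner`, so that `inner` sits inside `external`.
nest_inside(::NestLeaf, inner::NestCtx) = inner
function nest_inside(external::NestLayer, inner::NestCtx)
    return NestLayer(external.label, nest_inside(external.child, inner))
end

# Read off the labels from the outermost layer to the innermost one.
nest_labels(::NestLeaf) = String[]
nest_labels(ctx::NestLayer) = pushfirst!(nest_labels(ctx.child), ctx.label)

# External context: conditioning on a.x. Submodel context: prefix a on its outermost layer.
toy_external = NestLayer("Condition(a.x)", NestLeaf())
toy_submodel_ctx = NestLayer("Prefix(a)", NestLeaf())

nest_labels(nest_inside(toy_external, toy_submodel_ctx))
```

In this toy arrangement the conditioning layer ends up outside the prefix layer, which is the ordering described for the outer conditioning case.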

(2b) At first glance, it seems like `tilde_assume` can take care of the `VarName` prefixing for us (as described in the first section).
However, this is not actually the case: `contextual_isassumption`, which is the function that calls `hasconditioned_nested`, is much higher in the call stack than `tilde_assume` is.
So, we need to explicitly prefix the `VarName` before passing it to `contextual_isassumption`.
This is done inside the `@model` macro, or technically, its subsidiary function `isassumption`.

## Nested submodels

Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of `PrefixContext`s which may be interspersed with `ConditionContext`s.
For example, in this series of nested submodels,

```{julia}
@model function charlie()
    x ~ Normal()
    y ~ Normal()
    return z ~ Normal()
end
@model function bravo()
    return b ~ to_submodel(charlie() | (@varname(x) => 1.0))
end
@model function alpha()
    return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0))
end
```

we expect that the only variable to be sampled should be `z` inside `charlie`, or rather, `a.b.z` once it has been through the prefixes.

```{julia}
keys(VarInfo(alpha()))
```

The general strategy that we adopt is similar to above.
Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be:

```{julia}
big_ctx = PrefixContext(
    @varname(a),
    ConditionContext(
        Dict(@varname(b.y) => 1.0),
        PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))),
    ),
)
```

We need several things to work correctly here: we need the `VarName` prefixing to behave correctly, and then we need to implement `hasconditioned_nested` and `getconditioned_nested` on the resulting prefixed `VarName`.
It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a _different direction_ to what most of DynamicPPL does.

Let's work with a function called `myprefix(::AbstractContext, ::VarName)` (to avoid confusion with any existing DynamicPPL function).
We should like `myprefix(big_ctx, @varname(x))` to return `@varname(a.b.x)`.
Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline:

```{julia}
using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext
using AbstractPPL: AbstractPPL

function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName)
    return myprefix(NodeTrait(ctx), ctx, vn)
end
function myprefix(::IsLeaf, ::AbstractContext, vn::VarName)
    return vn
end
function myprefix(::IsParent, ctx::AbstractContext, vn::VarName)
    return myprefix(childcontext(ctx), vn)
end
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
    # The functionality to actually manipulate the VarNames is in AbstractPPL
    new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix)
    # Then pass to the child context
    return myprefix(childcontext(ctx), new_vn)
end

myprefix(big_ctx, @varname(x))
```

This implementation is clearly not correct: it attaches the _outer_ prefix `a` first and the _inner_ prefix `b` last, so the inner prefix ends up outermost in the resulting `VarName`.

The right way to implement `myprefix` is to, essentially, reverse the order of two lines above:

```{julia}
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
    # Pass to the child context first
    new_vn = myprefix(childcontext(ctx), vn)
    # Then apply this context's prefix
    return AbstractPPL.prefix(new_vn, ctx.vn_prefix)
end

myprefix(big_ctx, @varname(x))
```

This is a much better result!
The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, uses a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions.
When editing this code, it is worth being mindful of this as a potential source of incorrectness.

::: {.callout-note}
If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of `myprefix` uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold.
:::
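
The fold analogy can be checked directly with Julia's built-in `foldl` and `foldr`, using plain strings in place of `VarName`s. The helper `toy_attach` and the list `toy_prefixes` are hypothetical names introduced only for this sketch.

```{julia}
# Prefixes listed in the order they are encountered, from the outermost
# context to the innermost one.
toy_prefixes = ["a", "b"]

# Hypothetical helper: attach a single prefix in front of a name.
toy_attach(prefix, name) = prefix * "." * name

# Left fold: the outermost prefix is attached first, so it ends up innermost.
toy_left = foldl((name, prefix) -> toy_attach(prefix, name), toy_prefixes; init="x")

# Right fold: the innermost prefix is attached first and the outermost one last.
toy_right = foldr(toy_attach, toy_prefixes; init="x")

(toy_left, toy_right)
```

The left fold yields `"b.a.x"`, the same kind of wrong ordering produced by the naive `myprefix`, while the right fold yields the desired `"a.b.x"`.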

## Loose ends 1: Manual prefixing

Sometimes users may want to manually prefix a model, for example:

```{julia}
@model function inner_manual()
    x ~ Normal()
    return y ~ Normal()
end

@model function outer_manual()
    return _unused ~ to_submodel(prefix(inner_manual(), :a), false)
end
```

In this case, the `VarName` on the left-hand side of the tilde is not used, and the prefix is instead specified using the `prefix` function.

The way to deal with this follows on from the previous discussion.
Specifically, we said that:

> [...] as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined [...]

When automatic prefixing is used, this application of `PrefixContext` occurs inside the `tilde_assume!!` method.
In the manual prefixing case, we need to make sure that `prefix(submodel::Model, ::Symbol)` does the same thing, i.e. it inserts a `PrefixContext` at the outermost layer of `submodel`'s context.
We can see that this is precisely what happens:

```{julia}
@model f() = x ~ Normal()

model = f()
prefixed_model = prefix(model, :a)

(model.context, prefixed_model.context)
```

## Loose ends 2: FixedContext

Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names.
(`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.)
This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same.
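
The difference in log density between conditioning and fixing can be illustrated by hand for a two-variable standard normal model. The functions below are a hypothetical, dependency-free sketch of the arithmetic only; they do not call DynamicPPL and the names are invented for this example.

```{julia}
# Log density of a standard normal, written out to keep the sketch dependency-free.
toy_std_normal_logpdf(x) = -0.5 * (x^2 + log(2π))

# Toy model: x ~ Normal() and y ~ Normal(), with y set to the value y_obs.
# Conditioning: the term for y is still included in the log density.
toy_logp_conditioned(x, y_obs) = toy_std_normal_logpdf(x) + toy_std_normal_logpdf(y_obs)
# Fixing: y is treated as a constant, so its term is dropped.
toy_logp_fixed(x, y_obs) = toy_std_normal_logpdf(x)

# The two log densities differ by exactly the term for y.
toy_logp_conditioned(0.3, 1.0) - toy_logp_fixed(0.3, 1.0) ≈ toy_std_normal_logpdf(1.0)
```

The comparison holds for any choice of `x` and `y_obs` in this sketch, since the only difference between the two functions is the log density term of the variable that was set.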
