|
| 1 | +--- |
| 2 | +title: "Conditioning and fixing in submodels" |
| 3 | +engine: julia |
| 4 | +--- |
| 5 | + |
| 6 | +## PrefixContext |
| 7 | + |
| 8 | +Submodels in DynamicPPL come with the notion of _prefixing_ variables: under the hood, this is implemented by adding a `PrefixContext` to the context stack. |
| 9 | + |
| 10 | +`PrefixContext` is a context that, as the name suggests, prefixes all variables inside a model with a given symbol. |
| 11 | +Thus, for example: |
| 12 | + |
| 13 | +```{julia} |
| 14 | +using DynamicPPL, Distributions |
| 15 | +
|
| 16 | +@model function f() |
| 17 | + x ~ Normal() |
| 18 | + return y ~ Normal() |
| 19 | +end |
| 20 | +
|
| 21 | +@model function g() |
| 22 | + return a ~ to_submodel(f()) |
| 23 | +end |
| 24 | +``` |
| 25 | + |
| 26 | +inside the submodel `f`, the variables `x` and `y` become `a.x` and `a.y` respectively. |
| 27 | +This is easiest to observe by running the model: |
| 28 | + |
| 29 | +```{julia} |
| 30 | +vi = VarInfo(g()) |
| 31 | +keys(vi) |
| 32 | +``` |
| 33 | + |
| 34 | +::: {.callout-note} |
| 35 | +In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. |
| 36 | +We will return to the 'manual prefixing' case later. |
| 37 | +::: |
| 38 | + |
| 39 | +The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is `tilde_assume`. |
| 40 | +The method responsible for it is `tilde_assume(::PrefixContext, right, vn, vi)`: this attaches the prefix in the context to the `VarName` argument, before recursively calling `tilde_assume` with the new prefixed `VarName`. |
| 41 | +This means that even though a statement `x ~ dist` still enters the tilde pipeline at the top level as `x`, if the model evaluation context contains a `PrefixContext`, any function after `tilde_assume(::PrefixContext, ...)` will see `a.x` instead. |
| 42 | + |
| 43 | +## ConditionContext |
| 44 | + |
| 45 | +`ConditionContext` is a context which stores values of variables that are to be conditioned on. |
| 46 | +These values may be stored as a `Dict` which maps `VarName`s to values, or alternatively as a `NamedTuple`. |
| 47 | +The latter only works correctly if all `VarName`s are 'basic', in that they have an identity optic (i.e., something like `a.x` or `a[1]` is forbidden). |
| 48 | +Because of this limitation, we will only use `Dict` in this example. |
| 49 | + |
| 50 | +::: {.callout-note} |
| 51 | +If a `ConditionContext` with a `NamedTuple` encounters anything to do with a prefix, its internal `NamedTuple` is converted to a `Dict` anyway, so it is quite reasonable to ignore the `NamedTuple` case in this exposition. |
| 52 | +::: |
| 53 | + |
| 54 | +One can inspect the conditioning values with, for example: |
| 55 | + |
| 56 | +```{julia} |
| 57 | +@model function d() |
| 58 | + x ~ Normal() |
| 59 | + return y ~ Normal() |
| 60 | +end |
| 61 | +
|
| 62 | +cond_model = d() | (@varname(x) => 1.0) |
| 63 | +cond_ctx = cond_model.context |
| 64 | +``` |
| 65 | + |
| 66 | +There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is. |
| 67 | + |
| 68 | +```{julia} |
| 69 | +DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x)) |
| 70 | +``` |
| 71 | + |
| 72 | +```{julia} |
| 73 | +DynamicPPL.getconditioned_nested(cond_ctx, @varname(x)) |
| 74 | +``` |
| 75 | + |
| 76 | +These functions are in turn used by the function `DynamicPPL.contextual_isassumption`, which is largely the same as `hasconditioned_nested`, but also checks whether the value is `missing` (in which case it isn't really conditioned). |
| 77 | + |
| 78 | +```{julia} |
| 79 | +DynamicPPL.contextual_isassumption(cond_ctx, @varname(x)) |
| 80 | +``` |
| 81 | + |
| 82 | +::: {.callout-note} |
| 83 | +Notice that (neglecting `missing` values) the return value of `contextual_isassumption` is the _opposite_ of `hasconditioned_nested`, i.e. for a variable that _is_ conditioned on, `contextual_isassumption` returns `false`. |
| 84 | +::: |
| 85 | + |
| 86 | +If a variable `x` is conditioned on, then the effect of this is to set the value of `x` to the given value (while still including its contribution to the log probability density). |
| 87 | +Since `x` is no longer a random variable, if we were to evaluate the model, we would find only one key in the `VarInfo`: |
| 88 | + |
| 89 | +```{julia} |
| 90 | +keys(VarInfo(cond_model)) |
| 91 | +``` |
| 92 | + |
| 93 | +## Joint behaviour: desiderata at the model level |
| 94 | + |
| 95 | +When paired together, these two contexts have the potential to cause substantial confusion: `PrefixContext` modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the `ConditionContext`. |
| 96 | + |
| 97 | +We begin by mentioning some high-level desiderata for their joint behaviour. |
| 98 | +Take these models, for example: |
| 99 | + |
| 100 | +```{julia} |
| 101 | +# We define a helper function to unwrap a layer of SamplingContext, to |
| 102 | +# avoid cluttering the print statements. |
| 103 | +unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context |
| 104 | +unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx |
| 105 | +
|
| 106 | +@model function inner() |
| 107 | + println("inner context: $(unwrap_sampling_context(__context__))") |
| 108 | + x ~ Normal() |
| 109 | + return y ~ Normal() |
| 110 | +end |
| 111 | +
|
| 112 | +@model function outer() |
| 113 | + println("outer context: $(unwrap_sampling_context(__context__))") |
| 114 | + return a ~ to_submodel(inner()) |
| 115 | +end |
| 116 | +
|
| 117 | +# 'Outer conditioning' |
| 118 | +with_outer_cond = outer() | (@varname(a.x) => 1.0) |
| 119 | +
|
| 120 | +# 'Inner conditioning' |
| 121 | +inner_cond = inner() | (@varname(x) => 1.0) |
| 122 | +@model function outer2() |
| 123 | + println("outer context: $(unwrap_sampling_context(__context__))") |
| 124 | + return a ~ to_submodel(inner_cond) |
| 125 | +end |
| 126 | +with_inner_cond = outer2() |
| 127 | +``` |
| 128 | + |
| 129 | +We want that: |
| 130 | + |
| 131 | + 1. `keys(VarInfo(outer()))` should return `[a.x, a.y]`; |
| 132 | + 2. `keys(VarInfo(with_outer_cond))` should return `[a.y]`; |
| 133 | + 3. `keys(VarInfo(with_inner_cond))` should return `[a.y]`, |
| 134 | + |
| 135 | +**In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.** |
| 136 | + |
| 137 | +This is an incredibly salient point because it means that submodels can be treated as individual, opaque objects, and we can condition them without needing to know what it will be prefixed with, or the context in which that submodel is being used. |
| 138 | +For example, this means we can reuse `inner_cond` in another model with a different prefix, and it will _still_ have its inner `x` value be conditioned, despite the prefix differing. |
| 139 | + |
| 140 | +::: {.callout-note} |
| 141 | +In the current version of DynamicPPL, these criteria are all fulfilled. |
| 142 | +However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside. |
| 143 | +(See [this GitHub issue](https://github.com/TuringLang/DynamicPPL.jl/issues/857) for more information; this issue was the direct motivation for this documentation page.) |
| 144 | +::: |
| 145 | + |
| 146 | +## Desiderata at the context level |
| 147 | + |
| 148 | +The above section describes how we expect conditioning and prefixing to behave from a user's perpective. |
| 149 | +We now turn to the question of how we implement this in terms of DynamicPPL contexts. |
| 150 | +We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour. |
| 151 | + |
| 152 | +**Point (1)** does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the `tilde_assume` method shown above. |
| 153 | + |
| 154 | +**Points (2) and (3)** are more tricky. |
| 155 | +As the reader may surmise, the difference between them is the order in which the contexts are stacked. |
| 156 | + |
| 157 | +For the _outer_ conditioning case (point (2)), the `ConditionContext` will contain a `VarName` that is already prefixed. |
| 158 | +When we enter the inner submodel, this `ConditionContext` has to be passed down and somehow combined with the `PrefixContext` that is created when we enter the submodel. |
| 159 | +We make the claim here that the best way to do this is to nest the `PrefixContext` _inside_ the `ConditionContext`. |
| 160 | +This is indeed what happens, as can be demonstrated by running the model. |
| 161 | + |
| 162 | +```{julia} |
| 163 | +with_outer_cond() |
| 164 | +``` |
| 165 | + |
| 166 | +For the _inner_ conditioning case (point (3)), the outer model is not run with any special context. |
| 167 | +The inner model will itself contain a `ConditionContext` will contain a `VarName` that is not prefixed. |
| 168 | +When we run the model, this `ConditionContext` should be then nested _inside_ a `PrefixContext` to form the final evaluation context. |
| 169 | +Again, we can run the model to see this in action: |
| 170 | + |
| 171 | +```{julia} |
| 172 | +with_inner_cond() |
| 173 | +``` |
| 174 | + |
| 175 | +Putting all of the information so far together, what it means is that if we have these two inner contexts (taken from above): |
| 176 | + |
| 177 | +```{julia} |
| 178 | +using DynamicPPL: PrefixContext, ConditionContext, DefaultContext |
| 179 | +
|
| 180 | +inner_ctx_with_outer_cond = ConditionContext( |
| 181 | + Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a)) |
| 182 | +) |
| 183 | +inner_ctx_with_inner_cond = PrefixContext( |
| 184 | + @varname(a), ConditionContext(Dict(@varname(x) => 1.0)) |
| 185 | +) |
| 186 | +``` |
| 187 | + |
| 188 | +then we want both of these to be `true` (and thankfully, they are!): |
| 189 | + |
| 190 | +```{julia} |
| 191 | +DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x)) |
| 192 | +``` |
| 193 | + |
| 194 | +```{julia} |
| 195 | +DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x)) |
| 196 | +``` |
| 197 | + |
| 198 | +This allows us to finally specify our task as follows: |
| 199 | + |
| 200 | +(1) Given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. |
| 201 | + |
| 202 | +(2) We need to make sure that both the correct arguments are supplied. In order to do so: |
| 203 | + |
| 204 | + - (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that `PrefixContext` is applied _inside_ the parent model's context, but _outside_ the submodel's own context. |
| 205 | + |
| 206 | + - (2b) We also need to make sure that the `VarName` passed to it is prefixed correctly. |
| 207 | + |
| 208 | +## How do we do it? |
| 209 | + |
| 210 | +(1) `hasconditioned_nested` and `getconditioned_nested` accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all `PrefixContext`s, and apply those prefixes to any conditioned variables below it in the stack. |
| 211 | +Once the `PrefixContext`s have been removed, one can then iterate through the context stack and check if any of the `ConditionContext`s contain the variable, or get the value itself. |
| 212 | +For more details the reader is encouraged to read the source code. |
| 213 | + |
| 214 | +(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. |
| 215 | +This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that _the model's context is nested inside the external context_. |
| 216 | +Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an external context to give the behaviour seen above. |
| 217 | + |
| 218 | +(2b) At first glance, it seems like `tilde_assume` can take care of the `VarName` prefixing for us (as described in the first section). |
| 219 | +However, this is not actually the case: `contextual_isassumption`, which is the function that calls `hasconditioned_nested`, is much higher in the call stack than `tilde_assume` is. |
| 220 | +So, we need to explicitly prefix it before passing it to `contextual_isassumption`. |
| 221 | +This is done inside the `@model` macro, or technically, its subsidiary function `isassumption`. |
| 222 | + |
| 223 | +## Nested submodels |
| 224 | + |
| 225 | +Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of `PrefixContext`s which may be interspersed with `ConditionContext`s. |
| 226 | +For example, in this series of nested submodels, |
| 227 | + |
| 228 | +```{julia} |
| 229 | +@model function charlie() |
| 230 | + x ~ Normal() |
| 231 | + y ~ Normal() |
| 232 | + return z ~ Normal() |
| 233 | +end |
| 234 | +@model function bravo() |
| 235 | + return b ~ to_submodel(charlie() | (@varname(x) => 1.0)) |
| 236 | +end |
| 237 | +@model function alpha() |
| 238 | + return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0)) |
| 239 | +end |
| 240 | +``` |
| 241 | + |
| 242 | +we expect that the only variable to be sampled should be `z` inside `charlie`, or rather, `a.b.z` once it has been through the prefixes. |
| 243 | + |
| 244 | +```{julia} |
| 245 | +keys(VarInfo(alpha())) |
| 246 | +``` |
| 247 | + |
| 248 | +The general strategy that we adopt is similar to above. |
| 249 | +Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: |
| 250 | + |
| 251 | +```{julia} |
| 252 | +big_ctx = PrefixContext( |
| 253 | + @varname(a), |
| 254 | + ConditionContext( |
| 255 | + Dict(@varname(b.y) => 1.0), |
| 256 | + PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))), |
| 257 | + ), |
| 258 | +) |
| 259 | +``` |
| 260 | + |
| 261 | +We need several things to work correctly here: we need the `VarName` prefixing to behave correctly, and then we need to implement `hasconditioned_nested` and `getconditioned_nested` on the resulting prefixed `VarName`. |
| 262 | +It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a _different direction_ to what most of DynamicPPL does. |
| 263 | + |
| 264 | +Let's work with a function called `myprefix(::AbstractContext, ::VarName)` (to avoid confusion with any existing DynamicPPL function). |
| 265 | +We should like `myprefix(big_ctx, @varname(x))` to return `@varname(a.b.x)`. |
| 266 | +Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline: |
| 267 | + |
| 268 | +```{julia} |
| 269 | +using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext |
| 270 | +using AbstractPPL: AbstractPPL |
| 271 | +
|
| 272 | +function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName) |
| 273 | + return myprefix(NodeTrait(ctx), ctx, vn) |
| 274 | +end |
| 275 | +function myprefix(::IsLeaf, ::AbstractContext, vn::VarName) |
| 276 | + return vn |
| 277 | +end |
| 278 | +function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) |
| 279 | + return myprefix(childcontext(ctx), vn) |
| 280 | +end |
| 281 | +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) |
| 282 | + # The functionality to actually manipulate the VarNames is in AbstractPPL |
| 283 | + new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix) |
| 284 | + # Then pass to the child context |
| 285 | + return myprefix(childcontext(ctx), new_vn) |
| 286 | +end |
| 287 | +
|
| 288 | +myprefix(big_ctx, @varname(x)) |
| 289 | +``` |
| 290 | + |
| 291 | +This implementation clearly is not correct, because it applies the _inner_ `PrefixContext` before the outer one. |
| 292 | + |
| 293 | +The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: |
| 294 | + |
| 295 | +```{julia} |
| 296 | +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) |
| 297 | + # Pass to the child context first |
| 298 | + new_vn = myprefix(childcontext(ctx), vn) |
| 299 | + # Then apply this context's prefix |
| 300 | + return AbstractPPL.prefix(new_vn, ctx.vn_prefix) |
| 301 | +end |
| 302 | +
|
| 303 | +myprefix(big_ctx, @varname(x)) |
| 304 | +``` |
| 305 | + |
| 306 | +This is a much better result! |
| 307 | +The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, use a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions. |
| 308 | +When editing this code, it is worth being mindful of this as a potential source of incorrectness. |
| 309 | + |
| 310 | +::: {.callout-note} |
| 311 | +If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of `myprefix` uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold. |
| 312 | +::: |
| 313 | + |
| 314 | +## Loose ends 1: Manual prefixing |
| 315 | + |
| 316 | +Sometimes users may want to manually prefix a model, for example: |
| 317 | + |
| 318 | +```{julia} |
| 319 | +@model function inner_manual() |
| 320 | + x ~ Normal() |
| 321 | + return y ~ Normal() |
| 322 | +end |
| 323 | +
|
| 324 | +@model function outer_manual() |
| 325 | + return _unused ~ to_submodel(prefix(inner_manual(), :a), false) |
| 326 | +end |
| 327 | +``` |
| 328 | + |
| 329 | +In this case, the `VarName` on the left-hand side of the tilde is not used, and the prefix is instead specified using the `prefix` function. |
| 330 | + |
| 331 | +The way to deal with this follows on from the previous discussion. |
| 332 | +Specifically, we said that: |
| 333 | + |
| 334 | +> [...] as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined [...] |
| 335 | +
|
| 336 | +When automatic prefixing is used, this application of `PrefixContext` occurs inside the `tilde_assume!!` method. |
| 337 | +In the manual prefixing case, we need to make sure that `prefix(submodel::Model, ::Symbol)` does the same thing, i.e. it inserts a `PrefixContext` at the outermost layer of `submodel`'s context. |
| 338 | +We can see that this is precisely what happens: |
| 339 | + |
| 340 | +```{julia} |
| 341 | +@model f() = x ~ Normal() |
| 342 | +
|
| 343 | +model = f() |
| 344 | +prefixed_model = prefix(model, :a) |
| 345 | +
|
| 346 | +(model.context, prefixed_model.context) |
| 347 | +``` |
| 348 | + |
| 349 | +## Loose ends 2: FixedContext |
| 350 | + |
| 351 | +Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names. |
| 352 | +(`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) |
| 353 | +This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same. |
0 commit comments