Skip to content

Commit f7884f1

Browse files
committed
Make PrefixContext contain a varname rather than symbol
1 parent 8c3bff4 commit f7884f1

File tree

6 files changed

+95
-70
lines changed

6 files changed

+95
-70
lines changed

docs/src/internals/submodel_condition.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ Putting all of the information so far together, what it means is that if we have
181181
using DynamicPPL: PrefixContext, ConditionContext, DefaultContext
182182
183183
inner_ctx_with_outer_cond = ConditionContext(
184-
Dict(@varname(a.x) => 1.0), PrefixContext{:a}(DefaultContext())
184+
Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a))
185185
)
186-
inner_ctx_with_inner_cond = PrefixContext{:a}(
187-
ConditionContext(Dict(@varname(x) => 1.0), DefaultContext())
186+
inner_ctx_with_inner_cond = PrefixContext(
187+
@varname(a), ConditionContext(Dict(@varname(x) => 1.0))
188188
)
189189
```
190190

@@ -252,10 +252,11 @@ The general strategy that we adopt is similar to above.
252252
Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be:
253253

254254
```@example
255-
big_ctx = PrefixContext{:a}(
255+
big_ctx = PrefixContext(
256+
@varname(a),
256257
ConditionContext(
257258
Dict(@varname(b.y) => 1.0),
258-
PrefixContext{:b}(ConditionContext(Dict(@varname(x) => 1.0))),
259+
PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))),
259260
),
260261
)
261262
```
@@ -280,9 +281,9 @@ end
280281
function myprefix(::IsParent, ctx::AbstractContext, vn::VarName)
281282
return myprefix(childcontext(ctx), vn)
282283
end
283-
function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix}
284+
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
284285
# The functionality to actually manipulate the VarNames is in AbstractPPL
285-
new_vn = AbstractPPL.prefix(vn, VarName{Prefix}())
286+
new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix)
286287
# Then pass to the child context
287288
return myprefix(childcontext(ctx), new_vn)
288289
end
@@ -295,11 +296,11 @@ This implementation clearly is not correct, because it applies the _inner_ `Pref
295296
The right way to implement `myprefix` is to, essentially, reverse the order of two lines above:
296297

297298
```@example
298-
function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix}
299+
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
299300
# Pass to the child context first
300301
new_vn = myprefix(childcontext(ctx), vn)
301302
# Then apply this context's prefix
302-
return AbstractPPL.prefix(new_vn, VarName{Prefix}())
303+
return AbstractPPL.prefix(new_vn, ctx.vn_prefix)
303304
end
304305
305306
myprefix(big_ctx, @varname(x))

src/context_implementations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ function tilde_assume!!(context, right, vn, vi)
131131
# change in the future.
132132
if should_auto_prefix(right)
133133
dppl_model = right.model.model # This isa DynamicPPL.Model
134-
prefixed_submodel_context = PrefixContext{getsym(vn)}(dppl_model.context)
134+
# TODO: This does _not_ correctly prefix varnames with non-identity lenses
135+
prefixed_submodel_context = PrefixContext(vn, dppl_model.context)
135136
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
136137
right = to_submodel(new_dppl_model, true)
137138
end

src/contexts.jl

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -237,36 +237,43 @@ function setchildcontext(parent::MiniBatchContext, child)
237237
end
238238

239239
"""
240-
PrefixContext{Prefix}(context)
240+
PrefixContext(vn::VarName[, context::AbstractContext])
241+
PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym}
241242
242243
Create a context that allows you to use the wrapped `context` when running the model and
243-
adds the `Prefix` to all parameters.
244+
prefixes all parameters with the VarName `vn`.
245+
246+
`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`.
247+
If `context` is not provided, it defaults to `DefaultContext()`.
244248
245249
This context is useful in nested models to ensure that the names of the parameters are
246250
unique.
247251
248252
See also: [`to_submodel`](@ref)
249253
"""
250-
struct PrefixContext{Prefix,C} <: AbstractContext
254+
struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext
255+
vn_prefix::Tvn
251256
context::C
252257
end
253-
function PrefixContext{Prefix}(context::AbstractContext) where {Prefix}
254-
return PrefixContext{Prefix,typeof(context)}(context)
258+
PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext())
259+
function PrefixContext(::Val{sym}, context::AbstractContext) where {sym}
260+
return PrefixContext(VarName{sym}(), context)
255261
end
262+
PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}())
256263

257264
NodeTrait(::PrefixContext) = IsParent()
258265
childcontext(context::PrefixContext) = context.context
259-
function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix}
260-
return PrefixContext{Prefix}(child)
266+
function setchildcontext(ctx::PrefixContext, child::AbstractContext)
267+
return PrefixContext(ctx.vn_prefix, child)
261268
end
262269

263270
"""
264271
prefix(ctx::AbstractContext, vn::VarName)
265272
266273
Apply the prefixes in the context `ctx` to the variable name `vn`.
267274
"""
268-
function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
269-
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}())
275+
function prefix(ctx::PrefixContext, vn::VarName)
276+
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix)
270277
end
271278
function prefix(ctx::AbstractContext, vn::VarName)
272279
return prefix(NodeTrait(ctx), ctx, vn)
@@ -295,14 +302,13 @@ not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you
295302
_do_ need to modify them, then you may need to use
296303
`prefix_cond_and_fixed_variables` instead.
297304
"""
298-
function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
305+
function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName)
299306
child_context = childcontext(ctx)
300307
# vn_prefixed contains the prefixes from all lower levels
301308
vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts(
302309
child_context, vn
303310
)
304-
return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()),
305-
child_context_without_prefixes
311+
return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes
306312
end
307313
function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName)
308314
return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn)
@@ -314,11 +320,15 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName
314320
end
315321

316322
"""
317-
prefix(model::Model, x)
318-
319-
Return `model` but with all random variables prefixed by `x`.
323+
prefix(model::Model, x::VarName)
324+
prefix(model::Model, x::Val{sym})
325+
prefix(model::Model, x::Symbol) # Not recommended
320326
321-
If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
327+
Return `model` but with all random variables prefixed by `x`, where `x` is either:
328+
- a `VarName` (e.g. `@varname(a)`),
329+
- a `Val{sym}` (e.g. `Val(:a)`), or
330+
- a `Symbol` (e.g. `:a`). This last method is not recommended as it
331+
introduces runtime overheads.
322332
323333
# Examples
324334
@@ -328,17 +338,19 @@ julia> using DynamicPPL: prefix
328338
julia> @model demo() = x ~ Dirac(1)
329339
demo (generic function with 2 methods)
330340
331-
julia> rand(prefix(demo(), :my_prefix))
341+
julia> rand(prefix(demo(), @my_prefix))
332342
(var"my_prefix.x" = 1,)
333343
334-
julia> # One can also use `Val` to avoid runtime overheads.
335-
rand(prefix(demo(), Val(:my_prefix)))
344+
julia> rand(prefix(demo(), Val(:my_prefix)))
336345
(var"my_prefix.x" = 1,)
337346
```
338347
"""
339-
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
340-
function prefix(model::Model, ::Val{x}) where {x}
341-
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
348+
prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context))
349+
function prefix(model::Model, x::Symbol)
350+
return contextualize(model, PrefixContext(VarName{x}(), model.context))
351+
end
352+
function prefix(model::Model, x::Val{sym}) where {sym}
353+
return contextualize(model, PrefixContext(VarName{sym}(), model.context))
342354
end
343355

344356
"""
@@ -426,7 +438,7 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn)
426438
function hasconditioned_nested(::IsParent, context, vn)
427439
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
428440
end
429-
function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
441+
function hasconditioned_nested(context::PrefixContext, vn)
430442
return hasconditioned_nested(collapse_prefix_stack(context), vn)
431443
end
432444

@@ -444,7 +456,7 @@ end
444456
function getconditioned_nested(::IsLeaf, context, vn)
445457
return error("context $(context) does not contain value for $vn")
446458
end
447-
function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
459+
function getconditioned_nested(context::PrefixContext, vn)
448460
return getconditioned_nested(collapse_prefix_stack(context), vn)
449461
end
450462
function getconditioned_nested(::IsParent, context, vn)
@@ -715,13 +727,13 @@ which explains this in much more detail.
715727
```jldoctest
716728
julia> using DynamicPPL: collapse_prefix_stack
717729
718-
julia> c1 = PrefixContext{:a}(ConditionContext((x=1, )));
730+
julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, )));
719731
720732
julia> collapse_prefix_stack(c1)
721733
ConditionContext(Dict(a.x => 1), DefaultContext())
722734
723735
julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both.
724-
c2 = PrefixContext{:a}(ConditionContext((x=1, ), PrefixContext{:b}(ConditionContext((y=2,)))));
736+
c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,)))));
725737
726738
julia> collapsed = collapse_prefix_stack(c2);
727739
@@ -733,14 +745,14 @@ julia> # `collapsed` really looks something like this:
733745
(1, 2)
734746
```
735747
"""
736-
function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix}
748+
function collapse_prefix_stack(context::PrefixContext)
737749
# Collapse the child context (thus applying any inner prefixes first)
738750
collapsed = collapse_prefix_stack(childcontext(context))
739751
# Prefix any conditioned variables with the current prefix
740752
# Note: prefix_conditioned_variables is O(N) in the depth of the context stack.
741753
# So is this function. In the worst case scenario, this is O(N^2) in the
742754
# depth of the context stack.
743-
return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}())
755+
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix)
744756
end
745757
function collapse_prefix_stack(context::AbstractContext)
746758
return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context)

src/model.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ julia> # Nested ones also work.
429429
# (Note that `PrefixContext` also prefixes the variables of any
430430
# ConditionContext that is _inside_ it; because of this, the type of the
431431
# container has to be broadened to a `Dict`.)
432-
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
432+
cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0);
433433
434434
julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)])
435435
true
@@ -441,7 +441,7 @@ julia> # Since we conditioned on `a.m`, it is not treated as a random variable.
441441
a.x
442442
443443
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
444-
cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
444+
cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
445445
446446
julia> conditioned(cm)
447447
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
@@ -769,7 +769,7 @@ julia> # Returns all the variables we have fixed on + their values.
769769
(x = 100.0, m = 1.0)
770770
771771
julia> # The rest of this is the same as the `condition` example above.
772-
cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0);
772+
cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0);
773773
774774
julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)])
775775
true
@@ -779,7 +779,7 @@ julia> keys(VarInfo(cm))
779779
a.x
780780
781781
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
782-
cm = fix(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
782+
cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0));
783783
784784
julia> fixed(cm)
785785
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:

src/submodel_macro.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,12 @@ end
223223
prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx)
224224
function prefix_submodel_context(prefix, ctx)
225225
# E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated.
226-
return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx))
226+
return :($(PrefixContext)($(Symbol)($(esc(prefix))), $ctx))
227227
end
228228

229229
function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx)
230230
# E.g. `prefix="asd"`.
231-
return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx))
231+
return :($(PrefixContext)($(esc(Meta.quot(Symbol(prefix)))), $ctx))
232232
end
233233

234234
function prefix_submodel_context(prefix::Bool, ctx)

0 commit comments

Comments
 (0)