@@ -237,36 +237,43 @@ function setchildcontext(parent::MiniBatchContext, child)
237
237
end
238
238
239
239
"""
240
- PrefixContext{Prefix}(context)
240
+ PrefixContext(vn::VarName[, context::AbstractContext])
241
+ PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym}
241
242
242
243
Create a context that allows you to use the wrapped `context` when running the model and
243
- adds the `Prefix` to all parameters.
244
+ prefixes all parameters with the VarName `vn`.
245
+
246
+ `PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`.
247
+ If `context` is not provided, it defaults to `DefaultContext()`.
244
248
245
249
This context is useful in nested models to ensure that the names of the parameters are
246
250
unique.
247
251
248
252
See also: [`to_submodel`](@ref)
249
253
"""
250
- struct PrefixContext{Prefix,C} <: AbstractContext
254
+ struct PrefixContext{Tvn<: VarName ,C<: AbstractContext } <: AbstractContext
255
+ vn_prefix:: Tvn
251
256
context:: C
252
257
end
253
- function PrefixContext {Prefix} (context:: AbstractContext ) where {Prefix}
254
- return PrefixContext {Prefix,typeof(context)} (context)
258
+ PrefixContext (vn:: VarName ) = PrefixContext (vn, DefaultContext ())
259
+ function PrefixContext (:: Val{sym} , context:: AbstractContext ) where {sym}
260
+ return PrefixContext (VarName {sym} (), context)
255
261
end
262
+ PrefixContext (:: Val{sym} ) where {sym} = PrefixContext (VarName {sym} ())
256
263
257
264
NodeTrait (:: PrefixContext ) = IsParent ()
258
265
childcontext (context:: PrefixContext ) = context. context
259
- function setchildcontext (:: PrefixContext{Prefix} , child) where {Prefix}
260
- return PrefixContext {Prefix} ( child)
266
+ function setchildcontext (ctx :: PrefixContext , child:: AbstractContext )
267
+ return PrefixContext (ctx . vn_prefix, child)
261
268
end
262
269
263
270
"""
264
271
prefix(ctx::AbstractContext, vn::VarName)
265
272
266
273
Apply the prefixes in the context `ctx` to the variable name `vn`.
267
274
"""
268
- function prefix (ctx:: PrefixContext{Prefix} , vn:: VarName ) where {Prefix}
269
- return AbstractPPL. prefix (prefix (childcontext (ctx), vn), VarName {Prefix} () )
275
+ function prefix (ctx:: PrefixContext , vn:: VarName )
276
+ return AbstractPPL. prefix (prefix (childcontext (ctx), vn), ctx . vn_prefix )
270
277
end
271
278
function prefix (ctx:: AbstractContext , vn:: VarName )
272
279
return prefix (NodeTrait (ctx), ctx, vn)
@@ -295,14 +302,13 @@ not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you
295
302
_do_ need to modify them, then you may need to use
296
303
`prefix_cond_and_fixed_variables` instead.
297
304
"""
298
- function prefix_and_strip_contexts (ctx:: PrefixContext{Prefix} , vn:: VarName ) where {Prefix}
305
+ function prefix_and_strip_contexts (ctx:: PrefixContext , vn:: VarName )
299
306
child_context = childcontext (ctx)
300
307
# vn_prefixed contains the prefixes from all lower levels
301
308
vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts (
302
309
child_context, vn
303
310
)
304
- return AbstractPPL. prefix (vn_prefixed, VarName {Prefix} ()),
305
- child_context_without_prefixes
311
+ return AbstractPPL. prefix (vn_prefixed, ctx. vn_prefix), child_context_without_prefixes
306
312
end
307
313
function prefix_and_strip_contexts (ctx:: AbstractContext , vn:: VarName )
308
314
return prefix_and_strip_contexts (NodeTrait (ctx), ctx, vn)
@@ -314,11 +320,15 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName
314
320
end
315
321
316
322
"""
317
- prefix(model::Model, x)
318
-
319
- Return `model` but with all random variables prefixed by `x`.
323
+ prefix(model::Model, x::VarName )
324
+ prefix(model::Model, x::Val{sym})
325
+ prefix(model::Model, x::Symbol) # Not recommended
320
326
321
- If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
327
+ Return `model` but with all random variables prefixed by `x`, where `x` is either:
328
+ - a `VarName` (e.g. `@varname(a)`),
329
+ - a `Val{sym}` (e.g. `Val(:a)`), or
330
+ - a `Symbol` (e.g. `:a`). This last method is not recommended as it
331
+ introduces runtime overheads.
322
332
323
333
# Examples
324
334
@@ -328,17 +338,19 @@ julia> using DynamicPPL: prefix
328
338
julia> @model demo() = x ~ Dirac(1)
329
339
demo (generic function with 2 methods)
330
340
331
- julia> rand(prefix(demo(), : my_prefix))
341
+ julia> rand(prefix(demo(), @ my_prefix))
332
342
(var"my_prefix.x" = 1,)
333
343
334
- julia> # One can also use `Val` to avoid runtime overheads.
335
- rand(prefix(demo(), Val(:my_prefix)))
344
+ julia> rand(prefix(demo(), Val(:my_prefix)))
336
345
(var"my_prefix.x" = 1,)
337
346
```
338
347
"""
339
- prefix (model:: Model , x) = contextualize (model, PrefixContext {Symbol(x)} (model. context))
340
- function prefix (model:: Model , :: Val{x} ) where {x}
341
- return contextualize (model, PrefixContext {Symbol(x)} (model. context))
348
+ prefix (model:: Model , x:: VarName ) = contextualize (model, PrefixContext (x, model. context))
349
+ function prefix (model:: Model , x:: Symbol )
350
+ return contextualize (model, PrefixContext (VarName {x} (), model. context))
351
+ end
352
+ function prefix (model:: Model , x:: Val{sym} ) where {sym}
353
+ return contextualize (model, PrefixContext (VarName {sym} (), model. context))
342
354
end
343
355
344
356
"""
@@ -426,7 +438,7 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn)
426
438
function hasconditioned_nested (:: IsParent , context, vn)
427
439
return hasconditioned (context, vn) || hasconditioned_nested (childcontext (context), vn)
428
440
end
429
- function hasconditioned_nested (context:: PrefixContext{Prefix} , vn) where {Prefix}
441
+ function hasconditioned_nested (context:: PrefixContext , vn)
430
442
return hasconditioned_nested (collapse_prefix_stack (context), vn)
431
443
end
432
444
444
456
function getconditioned_nested (:: IsLeaf , context, vn)
445
457
return error (" context $(context) does not contain value for $vn " )
446
458
end
447
- function getconditioned_nested (context:: PrefixContext{Prefix} , vn) where {Prefix}
459
+ function getconditioned_nested (context:: PrefixContext , vn)
448
460
return getconditioned_nested (collapse_prefix_stack (context), vn)
449
461
end
450
462
function getconditioned_nested (:: IsParent , context, vn)
@@ -715,13 +727,13 @@ which explains this in much more detail.
715
727
```jldoctest
716
728
julia> using DynamicPPL: collapse_prefix_stack
717
729
718
- julia> c1 = PrefixContext{:a}( ConditionContext((x=1, )));
730
+ julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, )));
719
731
720
732
julia> collapse_prefix_stack(c1)
721
733
ConditionContext(Dict(a.x => 1), DefaultContext())
722
734
723
735
julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both.
724
- c2 = PrefixContext{:a}( ConditionContext((x=1, ), PrefixContext{:b}( ConditionContext((y=2,)))));
736
+ c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,)))));
725
737
726
738
julia> collapsed = collapse_prefix_stack(c2);
727
739
@@ -733,14 +745,14 @@ julia> # `collapsed` really looks something like this:
733
745
(1, 2)
734
746
```
735
747
"""
736
- function collapse_prefix_stack (context:: PrefixContext{Prefix} ) where {Prefix}
748
+ function collapse_prefix_stack (context:: PrefixContext )
737
749
# Collapse the child context (thus applying any inner prefixes first)
738
750
collapsed = collapse_prefix_stack (childcontext (context))
739
751
# Prefix any conditioned variables with the current prefix
740
752
# Note: prefix_conditioned_variables is O(N) in the depth of the context stack.
741
753
# So is this function. In the worst case scenario, this is O(N^2) in the
742
754
# depth of the context stack.
743
- return prefix_cond_and_fixed_variables (collapsed, VarName {Prefix} () )
755
+ return prefix_cond_and_fixed_variables (collapsed, context . vn_prefix )
744
756
end
745
757
function collapse_prefix_stack (context:: AbstractContext )
746
758
return collapse_prefix_stack (NodeTrait (collapse_prefix_stack, context), context)
0 commit comments