Skip to content

Commit a3f0923

Browse files
penelopeysmmhauru
andauthored
DynamicPPL 0.36 (#2535)
* DynamicPPL 0.36 * Fix prefixing test and docs * Fix deprecation warning for VarName(::Symbol) * Allow GibbsContext to wrap PrefixContext (but only PrefixContext) * Bump minor version instead * Enable non-identity VarNames in Gibbs Closes #2403 * Add Gibbs tests for non-identity VarNames and submodels * Update src/mcmc/gibbs.jl Co-authored-by: Markus Hauru <markus@mhauru.org> * Add changelog note about Gibbs * Add more non-identity varname tests --------- Co-authored-by: Markus Hauru <markus@mhauru.org>
1 parent 3901096 commit a3f0923

File tree

8 files changed

+184
-55
lines changed

8 files changed

+184
-55
lines changed

HISTORY.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,43 @@
1+
# Release 0.38.0
2+
3+
## DynamicPPL version
4+
5+
DynamicPPL compatibility has been bumped to 0.36.
6+
This brings with it a number of changes: the ones most likely to affect you are submodel prefixing and conditioning.
7+
Variables in submodels are now represented correctly with field accessors.
8+
For example:
9+
10+
```julia
11+
using Turing
12+
@model inner() = x ~ Normal()
13+
@model outer() = a ~ to_submodel(inner())
14+
```
15+
16+
`keys(VarInfo(outer()))` now returns `[@varname(a.x)]` instead of `[@varname(var"a.x")]`
17+
18+
Furthermore, you can now either condition on the outer model like `outer() | (@varname(a.x) => 1.0)`, or the inner model like `inner() | (@varname(x) => 1.0)`.
19+
If you use the conditioned inner model as a submodel, the conditioning will still apply correctly.
20+
21+
Please see [the DynamicPPL release notes](https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.36.0) for fuller details.
22+
23+
## Gibbs sampler
24+
25+
Turing's Gibbs sampler now allows for more complex `VarName`s, such as `x[1]` or `x.a`, to be used.
26+
For example, you can now do this:
27+
28+
```julia
29+
@model function f()
30+
x = Vector{Float64}(undef, 2)
31+
x[1] ~ Normal()
32+
return x[2] ~ Normal()
33+
end
34+
sample(f(), Gibbs(@varname(x[1]) => MH(), @varname(x[2]) => MH()), 100)
35+
```
36+
37+
Performance for the cases which used to previously work (i.e. `VarName`s like `x` which only consist of a single symbol) is unaffected, and `VarNames` with only field accessors (e.g. `x.a`) should be equally fast.
38+
It is possible that `VarNames` with indexing (e.g. `x[1]`) may be slower (although this is still an improvement over not working at all!).
39+
If you find any cases where you think the performance is worse than it should be, please do file an issue.
40+
141
# Release 0.37.1
242

343
`maximum_a_posteriori` and `maximum_likelihood` now perform sanity checks on the model before running the optimisation.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.37.1"
3+
version = "0.38.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -62,7 +62,7 @@ Distributions = "0.25.77"
6262
DistributionsAD = "0.6"
6363
DocStringExtensions = "0.8, 0.9"
6464
DynamicHMC = "3.4"
65-
DynamicPPL = "0.35"
65+
DynamicPPL = "0.36"
6666
EllipticalSliceSampling = "0.5, 1, 2"
6767
ForwardDiff = "0.10.3"
6868
Libtask = "0.8.8"

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
4040
| `@model` | [`DynamicPPL.@model`](@extref) | Define a probabilistic model |
4141
| `@varname` | [`AbstractPPL.@varname`](@extref) | Generate a `VarName` from a Julia expression |
4242
| `to_submodel` | [`DynamicPPL.to_submodel`](@extref) | Define a submodel |
43-
| `prefix` | [`DynamicPPL.prefix`](@extref) | Prefix all variable names in a model with a given symbol |
43+
| `prefix` | [`DynamicPPL.prefix`](@extref) | Prefix all variable names in a model with a given VarName |
4444
| `LogDensityFunction` | [`DynamicPPL.LogDensityFunction`](@extref) | A struct containing all information about how to evaluate a model. Mostly for advanced users |
4545

4646
### Inference

src/mcmc/Inference.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ using ..Essential
44
using DynamicPPL:
55
Metadata,
66
VarInfo,
7-
TypedVarInfo,
87
# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only
9-
# implemented for TypedVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it
8+
# implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it
109
# or implement it for other VarInfo types and export it from DPPL.
1110
all_varnames_grouped_by_symbol,
1211
syms,
@@ -161,7 +160,7 @@ function externalsampler(
161160
end
162161

163162
# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
164-
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
163+
function DynamicPPL.unflatten(vi::DynamicPPL.NTVarInfo, θ::NamedTuple)
165164
set_namedtuple!(deepcopy(vi), θ)
166165
return vi
167166
end

src/mcmc/gibbs.jl

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
2121
isgibbscomponent(::AdvancedHMC.HMC) = true
2222
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
2323

24+
function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
25+
return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
26+
end
27+
can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
28+
2429
# Basically like a `DynamicPPL.FixedContext` but
2530
# 1. Hijacks the tilde pipeline to fix variables.
2631
# 2. Computes the log-probability of the fixed variables.
@@ -54,8 +59,13 @@ for type stability of `tilde_assume`.
5459
# Fields
5560
$(FIELDS)
5661
"""
57-
struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <:
58-
DynamicPPL.AbstractContext
62+
struct GibbsContext{
63+
VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext
64+
} <: DynamicPPL.AbstractContext
65+
"""
66+
the VarNames being sampled
67+
"""
68+
target_varnames::VNs
5969
"""
6070
a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both
6171
those fixed and those being sampled. We use a `Ref` because this field may need to be
@@ -67,26 +77,14 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont
6777
"""
6878
context::Ctx
6979

70-
function GibbsContext{VNs}(global_varinfo, context) where {VNs}
71-
if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf)
72-
error("GibbsContext can only wrap a leaf context, not a $(context).")
73-
end
74-
return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context)
75-
end
76-
7780
function GibbsContext(target_varnames, global_varinfo, context)
78-
if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf)
79-
error("GibbsContext can only wrap a leaf context, not a $(context).")
81+
if !can_be_wrapped(context)
82+
error("GibbsContext can only wrap a leaf or prefix context, not a $(context).")
8083
end
81-
if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames)
82-
msg =
83-
"All Gibbs target variables must have identity lenses. " *
84-
"For example, you can't have `@varname(x.a[1])` as a target variable, " *
85-
"only `@varname(x)`."
86-
error(msg)
87-
end
88-
vn_sym = tuple(unique((DynamicPPL.getsym(vn) for vn in target_varnames))...)
89-
return new{vn_sym,typeof(global_varinfo),typeof(context)}(global_varinfo, context)
84+
target_varnames = tuple(target_varnames...) # Allow vectors.
85+
return new{typeof(target_varnames),typeof(global_varinfo),typeof(context)}(
86+
target_varnames, global_varinfo, context
87+
)
9088
end
9189
end
9290

@@ -96,8 +94,10 @@ end
9694

9795
DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent()
9896
DynamicPPL.childcontext(context::GibbsContext) = context.context
99-
function DynamicPPL.setchildcontext(context::GibbsContext{VNs}, childcontext) where {VNs}
100-
return GibbsContext{VNs}(Ref(context.global_varinfo[]), childcontext)
97+
function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
98+
return GibbsContext(
99+
context.target_varnames, Ref(context.global_varinfo[]), childcontext
100+
)
101101
end
102102

103103
get_global_varinfo(context::GibbsContext) = context.global_varinfo[]
@@ -129,7 +129,9 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
129129
return map(Base.Fix1(get_conditioned_gibbs, context), vns)
130130
end
131131

132-
is_target_varname(::GibbsContext{VNs}, ::VarName{sym}) where {VNs,sym} = sym in VNs
132+
function is_target_varname(ctx::GibbsContext, vn::VarName)
133+
return any(Base.Fix2(subsumes, vn), ctx.target_varnames)
134+
end
133135

134136
function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
135137
num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns))
@@ -145,6 +147,37 @@ end
145147
# Tilde pipeline
146148
function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
147149
child_context = DynamicPPL.childcontext(context)
150+
151+
# Note that `child_context` may contain `PrefixContext`s -- in which case
152+
# we need to make sure that vn is appropriately prefixed before we handle
153+
# the `GibbsContext` behaviour below. For example, consider the following:
154+
# @model inner() = x ~ Normal()
155+
# @model outer() = a ~ to_submodel(inner())
156+
# If we run this with `Gibbs(@varname(a.x) => MH())`, then when we are
157+
# executing the submodel, the `context` will contain the `@varname(a.x)`
158+
# variable; `child_context` will contain `PrefixContext(@varname(a))`; and
159+
# `vn` will just be `@varname(x)`. If we just simply run
160+
# `is_target_varname(context, vn)`, it will return false, and everything
161+
# will be messed up.
162+
# TODO(penelopeysm): This 'problem' could be solved if we made GibbsContext a
163+
# leaf context and wrapped the PrefixContext _above_ the GibbsContext, so
164+
# that the prefixing would be handled by tilde_assume(::PrefixContext, ...)
165+
# _before_ we hit this method.
166+
# In the current state of GibbsContext, doing this would require
167+
# special-casing the way PrefixContext is used to wrap the leaf context.
168+
# This is very inconvenient because PrefixContext's behaviour is defined in
169+
# DynamicPPL, and we would basically have to create a new method in Turing
170+
# and override it for GibbsContext. Indeed, a better way to do this would
171+
# be to make GibbsContext a leaf context. In this case, we would be able to
172+
# rely on the existing behaviour of DynamicPPL.make_evaluate_args_and_kwargs
173+
# to correctly wrap the PrefixContext around the GibbsContext. This is very
174+
# tricky to correctly do now, but once we remove the other leaf contexts
175+
# (i.e. PriorContext and LikelihoodContext), we should be able to do this.
176+
# This is already implemented in
177+
# https://github.com/TuringLang/DynamicPPL.jl/pull/885/ but not yet
178+
# released. Exciting!
179+
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
180+
148181
return if is_target_varname(context, vn)
149182
# Fall back to the default behavior.
150183
DynamicPPL.tilde_assume(child_context, right, vn, vi)
@@ -177,6 +210,8 @@ function DynamicPPL.tilde_assume(
177210
)
178211
# See comment in the above, rng-less version of this method for an explanation.
179212
child_context = DynamicPPL.childcontext(context)
213+
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
214+
180215
return if is_target_varname(context, vn)
181216
DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
182217
elseif has_conditioned_gibbs(context, vn)
@@ -232,9 +267,11 @@ end
232267
wrap_in_sampler(x::AbstractMCMC.AbstractSampler) = x
233268
wrap_in_sampler(x::InferenceAlgorithm) = DynamicPPL.Sampler(x)
234269

235-
to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)]
270+
to_varname(x::VarName) = x
271+
to_varname(x::Symbol) = VarName{x}()
272+
to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)]
236273
# Any other value is assumed to be an iterable of VarNames and Symbols.
237-
to_varname_list(t) = collect(map(VarName, t))
274+
to_varname_list(t) = collect(map(to_varname, t))
238275

239276
"""
240277
Gibbs

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
4040

4141
[compat]
4242
AbstractMCMC = "5"
43-
AbstractPPL = "0.9, 0.10"
43+
AbstractPPL = "0.9, 0.10, 0.11"
4444
AdvancedMH = "0.6, 0.7, 0.8"
4545
AdvancedPS = "=0.6.0"
4646
AdvancedVI = "0.2"
@@ -52,7 +52,7 @@ Combinatorics = "1"
5252
Distributions = "0.25"
5353
DistributionsAD = "0.6.3"
5454
DynamicHMC = "2.1.6, 3.0"
55-
DynamicPPL = "0.35"
55+
DynamicPPL = "0.36"
5656
FiniteDifferences = "0.10.8, 0.11, 0.12"
5757
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
5858
HypothesisTests = "0.11"

0 commit comments

Comments
 (0)