Skip to content

Commit af64356

Browse files
committed
moved impl of values_as_in_model to separate file due to size of impl
1 parent 7f0ff38 commit af64356

File tree

3 files changed

+182
-181
lines changed

3 files changed

+182
-181
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ include("transforming.jl")
179179
include("logdensityfunction.jl")
180180
include("model_utils.jl")
181181
include("extract_priors.jl")
182+
include("values_as_in_model.jl")
182183

183184
if !isdefined(Base, :get_extension)
184185
using Requires

src/contexts.jl

Lines changed: 0 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -664,184 +664,3 @@ function fixed(context::FixedContext)
664664
# precedence over decendants of `context`.
665665
return merge(context.values, fixed(childcontext(context)))
666666
end
667-
668-
"""
669-
ValuesAsInModelContext
670-
671-
A context that is used by [`values_as_in_model`](@ref) to obtain values
672-
of the model parameters as they are in the model.
673-
674-
This is particularly useful when working in unconstrained space, but one
675-
wants to extract the realization of a model in a constrained space.
676-
677-
# Fields
678-
$(TYPEDFIELDS)
679-
"""
680-
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
681-
"values that are extracted from the model"
682-
values::T
683-
"child context"
684-
context::C
685-
end
686-
687-
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
688-
function ValuesAsInModelContext(context::AbstractContext)
689-
return ValuesAsInModelContext(OrderedDict(), context)
690-
end
691-
692-
NodeTrait(::ValuesAsInModelContext) = IsParent()
693-
childcontext(context::ValuesAsInModelContext) = context.context
694-
function setchildcontext(context::ValuesAsInModelContext, child)
695-
return ValuesAsInModelContext(context.values, child)
696-
end
697-
698-
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
699-
return setindex!(context.values, copy(value), vn)
700-
end
701-
702-
function broadcast_push!(context::ValuesAsInModelContext, vns, values)
703-
return push!.((context,), vns, values)
704-
end
705-
706-
# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
707-
function broadcast_push!(
708-
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
709-
)
710-
for (vn, col) in zip(vns, eachcol(values))
711-
push!(context, vn, col)
712-
end
713-
end
714-
715-
# `tilde_asssume`
716-
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
717-
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
718-
# Save the value.
719-
push!(context, vn, value)
720-
# Save the value.
721-
# Pass on.
722-
return value, logp, vi
723-
end
724-
function tilde_assume(
725-
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
726-
)
727-
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
728-
# Save the value.
729-
push!(context, vn, value)
730-
# Pass on.
731-
return value, logp, vi
732-
end
733-
734-
# `dot_tilde_assume`
735-
function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi)
736-
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)
737-
738-
# Save the value.
739-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
740-
broadcast_push!(context, _vns, value)
741-
742-
return value, logp, vi
743-
end
744-
function dot_tilde_assume(
745-
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi
746-
)
747-
value, logp, vi = dot_tilde_assume(
748-
rng, childcontext(context), sampler, right, left, vn, vi
749-
)
750-
# Save the value.
751-
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
752-
broadcast_push!(context, _vns, value)
753-
754-
return value, logp, vi
755-
end
756-
757-
"""
758-
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
759-
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
760-
761-
Get the values of `varinfo` as they would be seen in the model.
762-
763-
If no `varinfo` is provided, then this is effectively the same as
764-
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
765-
766-
More specifically, this method attempts to extract the realization _as seen in the model_.
767-
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
768-
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
769-
space.
770-
771-
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
772-
of additional model evaluations.
773-
774-
# Arguments
775-
- `model::Model`: model to extract realizations from.
776-
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
777-
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
778-
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
779-
780-
# Examples
781-
782-
## When `VarInfo` fails
783-
784-
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
785-
786-
```jldoctest
787-
julia> using Distributions, StableRNGs
788-
789-
julia> rng = StableRNG(42);
790-
791-
julia> @model function model_changing_support()
792-
x ~ Bernoulli(0.5)
793-
y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
794-
end;
795-
796-
julia> model = model_changing_support();
797-
798-
julia> # Construct initial type-stable `VarInfo`.
799-
varinfo = VarInfo(rng, model);
800-
801-
julia> # Link it so it works in unconstrained space.
802-
varinfo_linked = DynamicPPL.link(varinfo, model);
803-
804-
julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
805-
# Flip `x` so we hit the other support of `y`.
806-
θ = [!varinfo[@varname(x)], rand(rng)];
807-
808-
julia> # Update the `VarInfo` with the new values.
809-
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
810-
811-
julia> # Determine the expected support of `y`.
812-
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
813-
(0, 1)
814-
815-
julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
816-
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
817-
818-
julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
819-
# used in the very first model evaluation, hence the support of `y`
820-
# is not updated even though `x` has changed.
821-
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
822-
false
823-
824-
julia> # Approach 2: Extract realizations using `values_as_in_model`.
825-
# (✓) `values_as_in_model` will re-run the model and extract
826-
# the correct realization of `y` given the new values of `x`.
827-
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
828-
true
829-
```
830-
"""
831-
function values_as_in_model(
832-
model::Model,
833-
varinfo::AbstractVarInfo=VarInfo(),
834-
context::AbstractContext=DefaultContext(),
835-
)
836-
context = ValuesAsInModelContext(context)
837-
evaluate!!(model, varinfo, context)
838-
return context.values
839-
end
840-
function values_as_in_model(
841-
rng::Random.AbstractRNG,
842-
model::Model,
843-
varinfo::AbstractVarInfo=VarInfo(),
844-
context::AbstractContext=DefaultContext(),
845-
)
846-
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
847-
end

src/values_as_in_model.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
2+
"""
3+
ValuesAsInModelContext
4+
5+
A context that is used by [`values_as_in_model`](@ref) to obtain values
6+
of the model parameters as they are in the model.
7+
8+
This is particularly useful when working in unconstrained space, but one
9+
wants to extract the realization of a model in a constrained space.
10+
11+
# Fields
12+
$(TYPEDFIELDS)
13+
"""
14+
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
15+
"values that are extracted from the model"
16+
values::T
17+
"child context"
18+
context::C
19+
end
20+
21+
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
22+
function ValuesAsInModelContext(context::AbstractContext)
23+
return ValuesAsInModelContext(OrderedDict(), context)
24+
end
25+
26+
NodeTrait(::ValuesAsInModelContext) = IsParent()
27+
childcontext(context::ValuesAsInModelContext) = context.context
28+
function setchildcontext(context::ValuesAsInModelContext, child)
29+
return ValuesAsInModelContext(context.values, child)
30+
end
31+
32+
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
33+
return setindex!(context.values, copy(value), vn)
34+
end
35+
36+
function broadcast_push!(context::ValuesAsInModelContext, vns, values)
37+
return push!.((context,), vns, values)
38+
end
39+
40+
# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
41+
function broadcast_push!(
42+
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
43+
)
44+
for (vn, col) in zip(vns, eachcol(values))
45+
push!(context, vn, col)
46+
end
47+
end
48+
49+
# `tilde_asssume`
50+
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
51+
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
52+
# Save the value.
53+
push!(context, vn, value)
54+
# Save the value.
55+
# Pass on.
56+
return value, logp, vi
57+
end
58+
function tilde_assume(
59+
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
60+
)
61+
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
62+
# Save the value.
63+
push!(context, vn, value)
64+
# Pass on.
65+
return value, logp, vi
66+
end
67+
68+
# `dot_tilde_assume`
69+
function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi)
70+
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)
71+
72+
# Save the value.
73+
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
74+
broadcast_push!(context, _vns, value)
75+
76+
return value, logp, vi
77+
end
78+
function dot_tilde_assume(
79+
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi
80+
)
81+
value, logp, vi = dot_tilde_assume(
82+
rng, childcontext(context), sampler, right, left, vn, vi
83+
)
84+
# Save the value.
85+
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
86+
broadcast_push!(context, _vns, value)
87+
88+
return value, logp, vi
89+
end
90+
91+
"""
92+
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
93+
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
94+
95+
Get the values of `varinfo` as they would be seen in the model.
96+
97+
If no `varinfo` is provided, then this is effectively the same as
98+
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
99+
100+
More specifically, this method attempts to extract the realization _as seen in the model_.
101+
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
102+
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
103+
space.
104+
105+
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
106+
of additional model evaluations.
107+
108+
# Arguments
109+
- `model::Model`: model to extract realizations from.
110+
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
111+
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
112+
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
113+
114+
# Examples
115+
116+
## When `VarInfo` fails
117+
118+
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
119+
120+
```jldoctest
121+
julia> using Distributions, StableRNGs
122+
123+
julia> rng = StableRNG(42);
124+
125+
julia> @model function model_changing_support()
126+
x ~ Bernoulli(0.5)
127+
y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
128+
end;
129+
130+
julia> model = model_changing_support();
131+
132+
julia> # Construct initial type-stable `VarInfo`.
133+
varinfo = VarInfo(rng, model);
134+
135+
julia> # Link it so it works in unconstrained space.
136+
varinfo_linked = DynamicPPL.link(varinfo, model);
137+
138+
julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
139+
# Flip `x` so we hit the other support of `y`.
140+
θ = [!varinfo[@varname(x)], rand(rng)];
141+
142+
julia> # Update the `VarInfo` with the new values.
143+
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
144+
145+
julia> # Determine the expected support of `y`.
146+
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
147+
(0, 1)
148+
149+
julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
150+
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
151+
152+
julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
153+
# used in the very first model evaluation, hence the support of `y`
154+
# is not updated even though `x` has changed.
155+
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
156+
false
157+
158+
julia> # Approach 2: Extract realizations using `values_as_in_model`.
159+
# (✓) `values_as_in_model` will re-run the model and extract
160+
# the correct realization of `y` given the new values of `x`.
161+
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
162+
true
163+
```
164+
"""
165+
function values_as_in_model(
166+
model::Model,
167+
varinfo::AbstractVarInfo=VarInfo(),
168+
context::AbstractContext=DefaultContext(),
169+
)
170+
context = ValuesAsInModelContext(context)
171+
evaluate!!(model, varinfo, context)
172+
return context.values
173+
end
174+
function values_as_in_model(
175+
rng::Random.AbstractRNG,
176+
model::Model,
177+
varinfo::AbstractVarInfo=VarInfo(),
178+
context::AbstractContext=DefaultContext(),
179+
)
180+
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
181+
end

0 commit comments

Comments
 (0)