@@ -664,184 +664,3 @@ function fixed(context::FixedContext)
664
664
# precedence over decendants of `context`.
665
665
return merge (context. values, fixed (childcontext (context)))
666
666
end
667
-
668
- """
669
- ValuesAsInModelContext
670
-
671
- A context that is used by [`values_as_in_model`](@ref) to obtain values
672
- of the model parameters as they are in the model.
673
-
674
- This is particularly useful when working in unconstrained space, but one
675
- wants to extract the realization of a model in a constrained space.
676
-
677
- # Fields
678
- $(TYPEDFIELDS)
679
- """
680
- struct ValuesAsInModelContext{T,C<: AbstractContext } <: AbstractContext
681
- " values that are extracted from the model"
682
- values:: T
683
- " child context"
684
- context:: C
685
- end
686
-
687
- ValuesAsInModelContext (values) = ValuesAsInModelContext (values, DefaultContext ())
688
- function ValuesAsInModelContext (context:: AbstractContext )
689
- return ValuesAsInModelContext (OrderedDict (), context)
690
- end
691
-
692
- NodeTrait (:: ValuesAsInModelContext ) = IsParent ()
693
- childcontext (context:: ValuesAsInModelContext ) = context. context
694
- function setchildcontext (context:: ValuesAsInModelContext , child)
695
- return ValuesAsInModelContext (context. values, child)
696
- end
697
-
698
- function Base. push! (context:: ValuesAsInModelContext , vn:: VarName , value)
699
- return setindex! (context. values, copy (value), vn)
700
- end
701
-
702
- function broadcast_push! (context:: ValuesAsInModelContext , vns, values)
703
- return push! .((context,), vns, values)
704
- end
705
-
706
- # This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
707
- function broadcast_push! (
708
- context:: ValuesAsInModelContext , vns:: AbstractVector , values:: AbstractMatrix
709
- )
710
- for (vn, col) in zip (vns, eachcol (values))
711
- push! (context, vn, col)
712
- end
713
- end
714
-
715
- # `tilde_asssume`
716
- function tilde_assume (context:: ValuesAsInModelContext , right, vn, vi)
717
- value, logp, vi = tilde_assume (childcontext (context), right, vn, vi)
718
- # Save the value.
719
- push! (context, vn, value)
720
- # Save the value.
721
- # Pass on.
722
- return value, logp, vi
723
- end
724
- function tilde_assume (
725
- rng:: Random.AbstractRNG , context:: ValuesAsInModelContext , sampler, right, vn, vi
726
- )
727
- value, logp, vi = tilde_assume (rng, childcontext (context), sampler, right, vn, vi)
728
- # Save the value.
729
- push! (context, vn, value)
730
- # Pass on.
731
- return value, logp, vi
732
- end
733
-
734
- # `dot_tilde_assume`
735
- function dot_tilde_assume (context:: ValuesAsInModelContext , right, left, vn, vi)
736
- value, logp, vi = dot_tilde_assume (childcontext (context), right, left, vn, vi)
737
-
738
- # Save the value.
739
- _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
740
- broadcast_push! (context, _vns, value)
741
-
742
- return value, logp, vi
743
- end
744
- function dot_tilde_assume (
745
- rng:: Random.AbstractRNG , context:: ValuesAsInModelContext , sampler, right, left, vn, vi
746
- )
747
- value, logp, vi = dot_tilde_assume (
748
- rng, childcontext (context), sampler, right, left, vn, vi
749
- )
750
- # Save the value.
751
- _right, _left, _vns = unwrap_right_left_vns (right, left, vn)
752
- broadcast_push! (context, _vns, value)
753
-
754
- return value, logp, vi
755
- end
756
-
757
- """
758
- values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
759
- values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
760
-
761
- Get the values of `varinfo` as they would be seen in the model.
762
-
763
- If no `varinfo` is provided, then this is effectively the same as
764
- [`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
765
-
766
- More specifically, this method attempts to extract the realization _as seen in the model_.
767
- For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
768
- with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
769
- space.
770
-
771
- Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
772
- of additional model evaluations.
773
-
774
- # Arguments
775
- - `model::Model`: model to extract realizations from.
776
- - `varinfo::AbstractVarInfo`: variable information to use for the extraction.
777
- - `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
778
- will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
779
-
780
- # Examples
781
-
782
- ## When `VarInfo` fails
783
-
784
- The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
785
-
786
- ```jldoctest
787
- julia> using Distributions, StableRNGs
788
-
789
- julia> rng = StableRNG(42);
790
-
791
- julia> @model function model_changing_support()
792
- x ~ Bernoulli(0.5)
793
- y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
794
- end;
795
-
796
- julia> model = model_changing_support();
797
-
798
- julia> # Construct initial type-stable `VarInfo`.
799
- varinfo = VarInfo(rng, model);
800
-
801
- julia> # Link it so it works in unconstrained space.
802
- varinfo_linked = DynamicPPL.link(varinfo, model);
803
-
804
- julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
805
- # Flip `x` so we hit the other support of `y`.
806
- θ = [!varinfo[@varname(x)], rand(rng)];
807
-
808
- julia> # Update the `VarInfo` with the new values.
809
- varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
810
-
811
- julia> # Determine the expected support of `y`.
812
- lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
813
- (0, 1)
814
-
815
- julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
816
- varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
817
-
818
- julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
819
- # used in the very first model evaluation, hence the support of `y`
820
- # is not updated even though `x` has changed.
821
- lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
822
- false
823
-
824
- julia> # Approach 2: Extract realizations using `values_as_in_model`.
825
- # (✓) `values_as_in_model` will re-run the model and extract
826
- # the correct realization of `y` given the new values of `x`.
827
- lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
828
- true
829
- ```
830
- """
831
- function values_as_in_model (
832
- model:: Model ,
833
- varinfo:: AbstractVarInfo = VarInfo (),
834
- context:: AbstractContext = DefaultContext (),
835
- )
836
- context = ValuesAsInModelContext (context)
837
- evaluate!! (model, varinfo, context)
838
- return context. values
839
- end
840
- function values_as_in_model (
841
- rng:: Random.AbstractRNG ,
842
- model:: Model ,
843
- varinfo:: AbstractVarInfo = VarInfo (),
844
- context:: AbstractContext = DefaultContext (),
845
- )
846
- return values_as_in_model (model, varinfo, SamplingContext (rng, context))
847
- end
0 commit comments