From c8cee860293040a37df9c9bf84482fe01ca55923 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 12 Jun 2025 15:50:52 +0100 Subject: [PATCH 01/10] Change `evaluate!!` API, add `sample!!` --- src/model.jl | 169 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 99 insertions(+), 70 deletions(-) diff --git a/src/model.jl b/src/model.jl index 3b93fa14d..602aafef8 100644 --- a/src/model.jl +++ b/src/model.jl @@ -794,15 +794,26 @@ julia> # Now `a.x` will be sampled. fixed(model::Model) = fixed(model.context) """ - (model::Model)([rng, varinfo, sampler, context]) + (model::Model)() + (model::Model)(rng[, varinfo, sampler, context]) -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. +Sample from the `model` using the `sampler` with random number generator `rng` +and the `context`, and store the sample and log joint probability in `varinfo`. -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns the model's return value. + +If no arguments are provided, uses the default random number generator and +samples from the prior. """ -(model::Model)(args...) = first(evaluate!!(model, args...)) +(model::Model)() = model(Random.default_rng()) +function (model::Model)( + rng::AbstractRNG, + varinfo::AbstractVarInfo=VarInfo(), + sampler::AbstractSampler=SampleFromPrior(), +) + spl_ctx = SamplingContext(rng, sampler, DefaultContext()) + return evaluate!!(model, varinfo, spl_ctx) +end """ use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) @@ -815,65 +826,51 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - evaluate!!(model::Model[, rng, varinfo, sampler, context]) - -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. + sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo) -Returns both the return-value of the original model, and the resulting varinfo. +Evaluate the `model` with the given `varinfo`, but perform sampling during the +evaluation by wrapping the model's context in a `SamplingContext`. -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function AbstractPPL.evaluate!!( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) - return if use_threadsafe_eval(context, varinfo) - evaluate_threadsafe!!(model, varinfo, context) - else - evaluate_threadunsafe!!(model, varinfo, context) - end +function sample!!(rng::AbstractRNG, model::Model, varinfo::AbstractVarInfo) + sampling_model = contextualize( + model, SamplingContext(rng, SampleFromPrior(), model.context) + ) + return evaluate!!(sampling_model, varinfo) end -function AbstractPPL.evaluate!!( - model::Model, - rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)) -end +""" + evaluate!!(model::Model, varinfo) + evaluate!!(model::Model, varinfo, context) -function AbstractPPL.evaluate!!(model::Model, context::AbstractContext) - return evaluate!!(model, VarInfo(), context) -end +Evaluate the `model` with the given `varinfo`. 
If an extra context stack is +provided, the model's context is inserted into that context stack. See +[`combine_model_and_external_contexts`](@ref). -function AbstractPPL.evaluate!!( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... -) - return evaluate!!(model, Random.default_rng(), args...) -end +If multiple threads are available, the varinfo provided will be wrapped in a +[`DynamicPPL.ThreadSafeVarInfo`](@ref) before evaluation. -# without VarInfo -function AbstractPPL.evaluate!!( - model::Model, - rng::Random.AbstractRNG, - sampler::AbstractSampler, - args::AbstractContext..., -) - return evaluate!!(model, rng, VarInfo(), sampler, args...) +Returns a tuple of the model's return value, plus the updated `varinfo` +(unwrapped if necessary). +""" +function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) + return if use_threadsafe_eval(model.context, varinfo) + evaluate_threadsafe!!(model, varinfo) + else + evaluate_threadunsafe!!(model, varinfo) + end end - -# without VarInfo and without AbstractSampler function AbstractPPL.evaluate!!( - model::Model, rng::Random.AbstractRNG, context::AbstractContext + model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) - return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context) + new_ctx = combine_model_and_external_contexts(model.context, context) + model = contextualize(model, new_ctx) + return evaluate!!(model, varinfo) end """ - evaluate_threadunsafe!!(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -882,8 +879,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ -function evaluate_threadunsafe!!(model, varinfo, context) - return _evaluate!!(model, resetlogp!!(varinfo), context) +function evaluate_threadunsafe!!(model, varinfo) + return _evaluate!!(model, resetlogp!!(varinfo)) end """ @@ -897,31 +894,74 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ -function evaluate_threadsafe!!(model, varinfo, context) +function evaluate_threadsafe!!(model, varinfo) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper, context) + result, wrapper_new = _evaluate!!(model, wrapper) + # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it + # will return the underlying VI, which is a bit counterintuitive (because + # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it + # again). return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) end """ + _evaluate!!(model::Model, varinfo) _evaluate!!(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. +Evaluate the `model` with the given `varinfo`. If an additional `context` is provided, +the model's context is combined with that context. + +This function does not wrap the varinfo in a `ThreadSafeVarInfo`. """ -function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - args, kwargs = make_evaluate_args_and_kwargs(model, varinfo, context) +function _evaluate!!(model::Model, varinfo::AbstractVarInfo) + args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) return model.f(args...; kwargs...) 
 end
+function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
+    # TODO(penelopeysm): We don't really need this, but it's a useful
+    # convenience method. We could remove it after we get rid of the
+    # evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
+    # with a TSVI themselves).
+    new_ctx = combine_model_and_external_contexts(model.context, context)
+    model = contextualize(model, new_ctx)
+    return _evaluate!!(model, varinfo)
+end
 
 is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
 
+"""
+    combine_model_and_external_contexts(model_context, external_context)
+
+Combine a context from a model and an external context into a single context.
+
+The resulting context stack has the following structure:
+
+    `external_context` -> `childcontext(external_context)` -> ... ->
+    `model_context` -> `childcontext(model_context)` -> ... ->
+    `leafcontext(external_context)`
+
+The reason for this is that we want to give `external_context` precedence over
+`model_context`, while also preserving the leaf context of `external_context`.
+We can do this by:
+
+1. Setting the leaf context of `model_context` to `leafcontext(external_context)`.
+2. Setting the leaf context of `external_context` to the context resulting from (1).
+"""
+function combine_model_and_external_contexts(
+    model_context::AbstractContext, external_context::AbstractContext
+)
+    return setleafcontext(
+        external_context, setleafcontext(model_context, leafcontext(external_context))
+    )
+end
+
 """
-    make_evaluate_args_and_kwargs(model, varinfo, context)
+    make_evaluate_args_and_kwargs(model, varinfo)
 
 Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`.
 """
 @generated function make_evaluate_args_and_kwargs(
-    model::Model{_F,argnames}, varinfo::AbstractVarInfo, context::AbstractContext
+    model::Model{_F,argnames}, varinfo::AbstractVarInfo
 ) where {_F,argnames}
     unwrap_args = [
         if is_splat_symbol(var)
             :($matchingvalue(varinfo, model.args.$var)...)
         else
             :($matchingvalue(varinfo, model.args.$var))
         end for var in argnames
     ]
-
-    # We want to give `context` precedence over `model.context` while also
-    # preserving the leaf context of `context`. We can do this by
-    # 1. Set the leaf context of `model.context` to `leafcontext(context)`.
-    # 2. Set leaf context of `context` to the context resulting from (1).
-    # The result is:
-    # `context` -> `childcontext(context)` -> ... -> `model.context`
-    # -> `childcontext(model.context)` -> ...
-> `leafcontext(context)` return quote - context_new = setleafcontext( - context, setleafcontext(model.context, leafcontext(context)) - ) args = ( model, # Maybe perform `invlink!!` once prior to evaluation to avoid From ec02e504136957c5dfe75c6e1ee254cd708e15e2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 12 Jun 2025 15:55:12 +0100 Subject: [PATCH 02/10] Fix literally everything else that I broke --- benchmarks/src/DynamicPPLBenchmarks.jl | 3 +- docs/src/api.md | 18 ++- ext/DynamicPPLJETExt.jl | 22 ++-- ext/DynamicPPLMCMCChainsExt.jl | 2 +- src/DynamicPPL.jl | 1 + src/compiler.jl | 30 ++--- src/debug_utils.jl | 53 +++------ src/experimental.jl | 16 +-- src/extract_priors.jl | 5 +- src/logdensityfunction.jl | 99 ++++++----------- src/model.jl | 86 +++++++------- src/pointwise_logdensities.jl | 59 ++++------ src/sampler.jl | 17 +-- src/simple_varinfo.jl | 49 ++++---- src/submodel_macro.jl | 30 ++--- src/test_utils/ad.jl | 21 +--- src/test_utils/model_interface.jl | 5 +- src/test_utils/varinfo.jl | 2 +- src/threadsafe.jl | 11 +- src/transforming.jl | 9 +- src/values_as_in_model.jl | 14 +-- src/varinfo.jl | 148 +++++-------------------- test/ad.jl | 5 +- test/compiler.jl | 23 ++-- test/context_implementations.jl | 4 +- test/contexts.jl | 3 +- test/debug_utils.jl | 18 +-- test/ext/DynamicPPLJETExt.jl | 5 +- test/linking.jl | 2 +- test/model.jl | 21 ++-- test/simple_varinfo.jl | 16 ++- test/threadsafe.jl | 40 +++---- test/varinfo.jl | 24 ++-- test/varnamedvector.jl | 8 +- 34 files changed, 346 insertions(+), 523 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 6f486e2f5..26ec35b65 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: end adbackend = to_backend(adbackend) - context = DynamicPPL.DefaultContext() if islinked vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend) + f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/src/api.md b/docs/src/api.md index 8e5c64886..b867e2e64 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,6 +36,12 @@ getargnames getmissings ``` +The context of a model can be set using [`contextualize`](@ref): + +```@docs +contextualize +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). @@ -438,13 +444,21 @@ DynamicPPL.varname_and_value_leaves ### Evaluation Contexts -Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref). +Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref). ```@docs AbstractPPL.evaluate!! ``` -The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function. +This method mutates the `varinfo` used for execution. +By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: + +```@docs +DynamicPPL.sample!! +``` + +The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. 
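+For example, a model can be put into "sampling mode" by replacing its context.
+A minimal sketch (`model` stands in for any model constructed with `@model`):
+
+```julia
+using DynamicPPL, Random
+
+# Wrap the model's existing context in a SamplingContext so that
+# evaluate!! draws fresh values instead of reusing those in the varinfo:
+sampling_model = contextualize(
+    model, SamplingContext(Random.default_rng(), SampleFromPrior(), model.context)
+)
+retval, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo())
+
+# Equivalent convenience method:
+retval, vi = DynamicPPL.sample!!(Random.default_rng(), model, VarInfo())
+```
+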
```@docs diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index aa95093f2..760d17bb0 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo; - only_ddpl::Bool=true, + model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) # Let's make sure that both evaluation and sampling doesn't result in type errors. - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, context - ) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. result = if only_ddpl @@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo( end function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true + model::DynamicPPL.Model; only_ddpl::Bool=true ) + # Use SamplingContext to test type stability. + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(model, context) + varinfo = DynamicPPL.typed_varinfo(sampling_model) # Let's make sure that both evaluation and sampling doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - model, context, varinfo; only_ddpl + sampling_model, varinfo; only_ddpl ) if !issuccess @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model, context) + DynamicPPL.untyped_varinfo(sampling_model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 70f0f0182..5e1c75aa5 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -115,7 +115,7 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - model(rng, varinfo, DynamicPPL.SampleFromPrior()) + varinfo = last(DynamicPPL.sample!!(rng, model, varinfo)) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9217deb4f..4bd4f2529 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -102,6 +102,7 @@ export AbstractVarInfo, # LogDensityFunction LogDensityFunction, # Contexts + contextualize, SamplingContext, DefaultContext, PrefixContext, diff --git a/src/compiler.jl b/src/compiler.jl index b783c2a13..22dff33a2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,4 +1,4 @@ -const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +const INTERNALNAMES = (:__model__, :__varinfo__) """ need_concretize(expr) @@ -63,9 +63,9 @@ used in its place. 
 """
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote if $(DynamicPPL.contextual_isassumption)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) - # Considered an assumption by `__context__` which means either: + # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of # the model arguments, hence we need to check this. @@ -116,7 +116,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) return :($(DynamicPPL.contextual_isfixed)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) )) end @@ -417,7 +417,7 @@ function generate_assign(left, right) return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) - $vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left))) + $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) __varinfo__ = $(map_accumulator!!)( $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) @@ -431,7 +431,11 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__ + __model__.context, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + nothing, + __varinfo__, ) $value end @@ -456,7 +460,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) @@ -464,12 +468,12 @@ function generate_tilde(left, right) # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) $left = $(DynamicPPL.getconditioned_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, + __model__.context, $(DynamicPPL.check_tilde_rhs)($dist), $(maybe_view(left)), $vn, @@ -494,7 +498,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( - __context__, + __model__.context, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) @@ -652,11 +656,7 @@ function build_output(modeldef, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( - [ - :(__model__::$(DynamicPPL.Model)), - :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__context__::$(DynamicPPL.AbstractContext)), - ], + [:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))], args, ) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 54852a736..4343ce8ac 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -131,9 +131,7 @@ A context used for checking validity of a model. 
# Fields $(FIELDS) """ -struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext - "model that is being run" - model::M +struct DebugContext{C<:AbstractContext} <: AbstractContext "context used for running the model" context::C "mapping from varnames to the number of times they have been seen" @@ -149,7 +147,6 @@ struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext end function DebugContext( - model::Model, context::AbstractContext=DefaultContext(); varnames_seen=OrderedDict{VarName,Int}(), statements=Vector{Stmt}(), @@ -158,7 +155,6 @@ function DebugContext( record_varinfo=false, ) return DebugContext( - model, context, varnames_seen, statements, @@ -344,7 +340,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int}) end # A check we run on the model before evaluating it. -function check_model_pre_evaluation(context::DebugContext, model::Model) +function check_model_pre_evaluation(model::Model) issuccess = true # If something is in the model arguments, then it should NOT be in `condition`, # nor should there be any symbol present in `condition` that has the same symbol. @@ -361,8 +357,8 @@ function check_model_pre_evaluation(context::DebugContext, model::Model) return issuccess end -function check_model_post_evaluation(context::DebugContext, model::Model) - return check_varnames_seen(context.varnames_seen) +function check_model_post_evaluation(model::Model) + return check_varnames_seen(model.context.varnames_seen) end """ @@ -438,25 +434,23 @@ function check_model_and_trace( rng::Random.AbstractRNG, model::Model; varinfo=VarInfo(), - context=SamplingContext(rng), error_on_failure=false, kwargs..., ) # Execute the model with the debug context. debug_context = DebugContext( - model, context; error_on_failure=error_on_failure, kwargs... + SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... ) + debug_model = DynamicPPL.contextualize(model, debug_context) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_context, model) + issuccess = check_model_pre_evaluation(debug_model) # Force single-threaded execution. - retval, varinfo_result = DynamicPPL.evaluate_threadunsafe!!( - model, varinfo, debug_context - ) + DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo) # Perform checks after evaluating the model. - issuccess &= check_model_post_evaluation(debug_context, model) + issuccess &= check_model_post_evaluation(debug_model) if !issuccess && error_on_failure error("model check failed") @@ -535,14 +529,13 @@ function has_static_constraints( end """ - gen_evaluator_call_with_types(model[, varinfo, context]) + gen_evaluator_call_with_types(model[, varinfo]) Generate the evaluator call and the types of the arguments. # Arguments - `model::Model`: The model whose evaluator is of interest. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Returns A 2-tuple with the following elements: @@ -551,11 +544,9 @@ A 2-tuple with the following elements: - `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator. 
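+
+# Example
+
+A sketch of pairing this with JET.jl (this assumes JET is loaded and that `m`
+is a model constructed with `@model`):
+
+```julia
+f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(m)
+JET.report_call(f, argtypes)
+```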
""" function gen_evaluator_call_with_types( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(), + model::Model, varinfo::AbstractVarInfo=VarInfo(model) ) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) return if isempty(kwargs) (model.f, Base.typesof(args...)) else @@ -564,7 +555,7 @@ function gen_evaluator_call_with_types( end """ - model_warntype(model[, varinfo, context]; optimize=true) + model_warntype(model[, varinfo]; optimize=true) Check the type stability of the model's evaluator, warning about any potential issues. @@ -573,23 +564,19 @@ This simply calls `@code_warntype` on the model's evaluator, filling in internal # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `false`. """ function model_warntype( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=false, + model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=false ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize) end """ - model_typed(model[, varinfo, context]; optimize=true) + model_typed(model[, varinfo]; optimize=true) Return the type inference for the model's evaluator. @@ -598,18 +585,14 @@ This simply calls `@code_typed` on the model's evaluator, filling in internal ar # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `true`. """ function model_typed( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=true, + model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=true ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize)) end diff --git a/src/experimental.jl b/src/experimental.jl index 84038803c..974912957 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -4,16 +4,15 @@ using DynamicPPL: DynamicPPL # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ - is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) + is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) -Check if the `model` supports evaluation using the provided `context` and `varinfo`. +Check if the `model` supports evaluation using the provided `varinfo`. !!! warning Loading JET.jl is required before calling this function. # Arguments - `model`: The model to verify the support for. 
-- `context`: The context to use for the model evaluation. - `varinfo`: The varinfo to verify the support for. # Keyword Arguments @@ -29,7 +28,7 @@ function is_suitable_varinfo end function _determine_varinfo_jet end """ - determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) + determine_suitable_varinfo(model; only_ddpl::Bool=true) Return a suitable varinfo for the given `model`. @@ -41,7 +40,6 @@ See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). # Arguments - `model`: The model for which to determine the varinfo. -- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. # Keyword Arguments - `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. @@ -85,14 +83,10 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) true ``` """ -function determine_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); - only_ddpl::Bool=true, -) +function determine_suitable_varinfo(model::DynamicPPL.Model; only_ddpl::Bool=true) # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model, context; only_ddpl) + _determine_varinfo_jet(model; only_ddpl) else # Warn the user. @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 9047c9f0a..ac25bbbf8 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -116,7 +116,8 @@ function extract_priors(rng::Random.AbstractRNG, model::Model) # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you # can't push new variables without knowing the num_produce. Remove this when possible. varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) - varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng))) + new_model = contextualize(model, SamplingContext(rng, model.context)) + varinfo = last(evaluate!!(new_model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end @@ -135,6 +136,6 @@ function extract_priors(model::Model, varinfo::AbstractVarInfo) varinfo = setaccs!!( deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator()) ) - varinfo = last(evaluate!!(model, varinfo, DefaultContext())) + varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index c5586f80f..e489b46ba 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,8 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -30,9 +29,8 @@ A struct which contains a model, along with all the information necessary to: that point. At its most basic level, a LogDensityFunction wraps the model together with its -the type of varinfo to be used, as well as the evaluation context. These must -be known in order to calculate the log density (using -[`DynamicPPL.evaluate!!`](@ref)). +the type of varinfo to be used. These must be known in order to calculate the +log density (using [`DynamicPPL.evaluate!!`](@ref)). 
If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -95,14 +93,12 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} } <: AbstractModel "model used for evaluation" model::M "varinfo used for evaluation" varinfo::V - "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" - context::C "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" adtype::AD "(internal use only) gradient preparation object for the model" @@ -110,35 +106,29 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=leafcontext(model.context); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing prep = nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo, context) + adtype = tweak_adtype(adtype, model, varinfo) # Check whether it is supported is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x) + prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x) else prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(varinfo), - DI.Constant(context), + logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo) ) end end - return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + return new{typeof(model),typeof(varinfo),typeof(adtype)}( + model, varinfo, adtype, prep ) end end @@ -149,9 +139,9 @@ end adtype::Union{Nothing,ADTypes.AbstractADType} ) -Create a new LogDensityFunction using the model, varinfo, and context from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass -`nothing` as the second argument. +Create a new LogDensityFunction using the model, and varinfo from the given +`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, +pass `nothing` as the second argument. """ function LogDensityFunction( f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} @@ -159,7 +149,7 @@ function LogDensityFunction( return if adtype === f.adtype f # Avoid recomputing prep if not needed else - LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) + LogDensityFunction(f.model, f.varinfo; adtype=adtype) end end @@ -168,20 +158,18 @@ end x::AbstractVector, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Evaluate the log density of the given `model` at the given parameter values `x`, -using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted -into it, and its own parameters are discarded. 
It does, however, determine whether the log -prior, likelihood, or joint is returned, based on which accumulators are set in it. +using the given `varinfo`. Note that the `varinfo` argument is provided only +for its structure, in the sense that the parameters from the vector `x` are +inserted into it, and its own parameters are discarded. It does, however, +determine whether the log prior, likelihood, or joint is returned, based on +which accumulators are set in it. """ -function logdensity_at( - x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) +function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo) varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new, context)) + varinfo_eval = last(evaluate!!(model, varinfo_new)) has_prior = hasacc(varinfo_eval, Val(:LogPrior)) has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) if has_prior && has_likelihood @@ -196,60 +184,48 @@ function logdensity_at( end """ - LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}( + LogDensityAt{M<:Model,V<:AbstractVarInfo}( model::M varinfo::V - context::C ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -varinfo, context)`. +varinfo)`. """ -struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} +struct LogDensityAt{M<:Model,V<:AbstractVarInfo} model::M varinfo::V - context::C -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.varinfo, ld.context) end +(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo) ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,Nothing}} -) where {M,V,C} + ::Type{<:LogDensityFunction{M,V,Nothing}} +) where {M,V} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,AD}} -) where {M,V,C,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,V,AD}} +) where {M,V,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo, f.context) + return logdensity_at(x, f.model, f.varinfo) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,C,AD}, x::AbstractVector -) where {M,V,C,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,V,AD}, x::AbstractVector +) where {M,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x - ) + DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x) else DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.varinfo), - DI.Constant(f.context), + logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo) ) end end @@ -264,7 +240,6 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) adtype::ADTypes.AbstractADType, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Return an 'optimised' form of the adtype. This is useful for doing @@ -275,9 +250,7 @@ model. 
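+An AD-backend package could, hypothetically, specialise this method; the
+following is a sketch under that assumption, not the actual extension code:
+
+```julia
+function DynamicPPL.tweak_adtype(
+    ::ADTypes.AutoForwardDiff{nothing}, ::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo
+)
+    # Assumption: a `nothing` chunk size means "pick one for me"; here we
+    # derive a concrete chunk size from the number of parameters.
+    return ADTypes.AutoForwardDiff(; chunksize=length(vi[:]))
+end
+```
+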
By default, this just returns the input unchanged. """ -tweak_adtype( - adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext -) = adtype +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype """ use_closure(adtype::ADTypes.AbstractADType) @@ -319,7 +292,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. """ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) + return LogDensityFunction(model, f.varinfo; adtype=f.adtype) end """ diff --git a/src/model.jl b/src/model.jl index 602aafef8..fd275dc17 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,6 +85,12 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k return Model(f, args, NamedTuple(kwargs), context) end +""" + contextualize(model::Model, context::AbstractContext) + +Return a new `Model` with the same evaluation function and other arguments, but +with its underlying context set to `context`. +""" function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end @@ -794,25 +800,22 @@ julia> # Now `a.x` will be sampled. fixed(model::Model) = fixed(model.context) """ - (model::Model)() - (model::Model)(rng[, varinfo, sampler, context]) + (model::Model)([rng, varinfo]) -Sample from the `model` using the `sampler` with random number generator `rng` -and the `context`, and store the sample and log joint probability in `varinfo`. +Sample from the prior of the `model` with random number generator `rng`. Returns the model's return value. -If no arguments are provided, uses the default random number generator and -samples from the prior. +Note that calling this with an existing `varinfo` object will mutate it. """ -(model::Model)() = model(Random.default_rng()) -function (model::Model)( - rng::AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), - sampler::AbstractSampler=SampleFromPrior(), -) - spl_ctx = SamplingContext(rng, sampler, DefaultContext()) - return evaluate!!(model, varinfo, spl_ctx) +(model::Model)() = model(Random.default_rng(), VarInfo()) +function (model::Model)(varinfo::AbstractVarInfo) + return model(Random.default_rng(), varinfo) +end +# ^ Weird Documenter.jl bug means that we have to write the two above separately +# as it can only detect the `function`-less syntax. +function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) + return first(sample!!(rng, model, varinfo)) end """ @@ -826,19 +829,30 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo) + sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation by wrapping the model's context in a `SamplingContext`. +evaluation using the given `sampler` by wrapping the model's context in a +`SamplingContext`. + +If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). Returns a tuple of the model's return value, plus the updated `varinfo` object. 
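+
+A minimal usage sketch (assuming `gdemo()` is a model constructed with
+`@model`):
+
+```julia
+using DynamicPPL, Random
+
+retval, vi = DynamicPPL.sample!!(Random.default_rng(), gdemo(), VarInfo())
+getlogjoint(vi)  # log joint density of the freshly sampled values
+```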
""" -function sample!!(rng::AbstractRNG, model::Model, varinfo::AbstractVarInfo) - sampling_model = contextualize( - model, SamplingContext(rng, SampleFromPrior(), model.context) - ) +function sample!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + sampler::AbstractSampler=SampleFromPrior(), +) + sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) return evaluate!!(sampling_model, varinfo) end +function sample!!( + model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() +) + return sample!!(Random.default_rng(), model, varinfo, sampler) +end """ evaluate!!(model::Model, varinfo) @@ -846,10 +860,10 @@ end Evaluate the `model` with the given `varinfo`. If an extra context stack is provided, the model's context is inserted into that context stack. See -[`combine_model_and_external_contexts`](@ref). +`combine_model_and_external_contexts`. If multiple threads are available, the varinfo provided will be wrapped in a -[`DynamicPPL.ThreadSafeVarInfo`](@ref) before evaluation. +`ThreadSafeVarInfo` before evaluation. Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). @@ -864,6 +878,10 @@ end function AbstractPPL.evaluate!!( model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) + Base.depwarn( + "The `context` argument to evaluate!!(model, varinfo, context) is deprecated.", + :dynamicppl_evaluate_context, + ) new_ctx = combine_model_and_external_contexts(model.context, context) model = contextualize(model, new_ctx) return evaluate!!(model, varinfo) @@ -911,7 +929,8 @@ end Evaluate the `model` with the given `varinfo`. If an additional `context` is provided, the model's context is combined with that context. -This function does not wrap the varinfo in a `ThreadSafeVarInfo`. +This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not +reset the log probability of the `varinfo` before running. """ function _evaluate!!(model::Model, varinfo::AbstractVarInfo) args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) @@ -978,7 +997,6 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. maybe_invlink_before_eval!!(varinfo, model), - context_new, $(unwrap_args...), ) kwargs = model.defaults @@ -1014,15 +1032,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate!!( - model, - SimpleVarInfo{Float64}(OrderedDict()), - # NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`, - # and b) avoid double-stacking the parent contexts. - SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)), - ), - ) + x = last(sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) return values_as(x, T) end @@ -1039,7 +1049,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logprior`](@ref) and [`loglikelihood`](@ref). 
""" function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo))) end """ @@ -1093,7 +1103,7 @@ function logprior(model::Model, varinfo::AbstractVarInfo) LogPriorAccumulator() end varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,)) - return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogprior(last(evaluate!!(model, varinfo))) end """ @@ -1147,7 +1157,7 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) LogLikelihoodAccumulator() end varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,)) - return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) + return getloglikelihood(last(evaluate!!(model, varinfo))) end """ @@ -1187,7 +1197,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end """ - predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) + predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) Generate samples from the posterior predictive distribution by evaluating `model` at each set of parameter values provided in `chain`. The number of posterior predictive samples matches @@ -1201,7 +1211,7 @@ function predict( return map(chain) do params_varinfo vi = deepcopy(varinfo) DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi, SampleFromPrior()) + model(rng, vi) return vi end end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index b6b97c8f9..59cc5e1bb 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -3,9 +3,10 @@ An accumulator that stores the log-probabilities of each variable in a model. -Internally this context stores the log-probabilities in a dictionary, where the keys are -the variable names and the values are vectors of log-probabilities. Each element in a vector -corresponds to one execution of the model. +Internally this accumulator stores the log-probabilities in a dictionary, where +the keys are the variable names and the values are vectors of +log-probabilities. Each element in a vector corresponds to one execution of the +model. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies which log-probabilities to store in the accumulator. `KeyType` is the type by which variable @@ -98,7 +99,6 @@ end model::Model, chain::Chains, keytype=String, - context=DefaultContext(), ::Val{whichlogprob}=Val(:both), ) @@ -107,9 +107,9 @@ with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. `context` is the evaluation context, -and `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, -`:prior`, or `:likelihood`. +Currently, only `String` and `VarName` are supported. `whichlogprob` specifies +which log-probabilities to compute. It can be `:both`, `:prior`, or +`:likelihood`. See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). 
@@ -211,11 +211,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ``` """ function pointwise_logdensities( - model::Model, - chain, - ::Type{KeyType}=String, - context::AbstractContext=DefaultContext(), - ::Val{whichlogprob}=Val(:both), + model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both) ) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) @@ -229,7 +225,7 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - vi = last(evaluate!!(model, vi, context)) + vi = last(evaluate!!(model, vi)) end logps = getacc(vi, Val(accumulator_name(AccType))).logps @@ -242,55 +238,46 @@ function pointwise_logdensities( end function pointwise_logdensities( - model::Model, - varinfo::AbstractVarInfo, - context::AbstractContext=DefaultContext(), - ::Val{whichlogprob}=Val(:both), + model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) ) where {whichlogprob} AccType = PointwiseLogProbAccumulator{whichlogprob} varinfo = setaccs!!(varinfo, (AccType(),)) - varinfo = last(evaluate!!(model, varinfo, context)) + varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ - pointwise_loglikelihoods(model, chain[, keytype, context]) + pointwise_loglikelihoods(model, chain[, keytype]) Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the likelihood terms. See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). """ -function pointwise_loglikelihoods( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} - return pointwise_logdensities(model, chain, T, context, Val(:likelihood)) +function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} + return pointwise_logdensities(model, chain, T, Val(:likelihood)) end -function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - return pointwise_logdensities(model, varinfo, context, Val(:likelihood)) +function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:likelihood)) end """ - pointwise_prior_logdensities(model, chain[, keytype, context]) + pointwise_prior_logdensities(model, chain[, keytype]) Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the prior terms. See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
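+
+A sketch (assuming `gdemo()` is a model constructed with `@model` and `chn` is
+an MCMC chain for it):
+
+```julia
+lps = pointwise_prior_logdensities(gdemo(), chn)
+lps["x"]  # per-sample log prior densities for a variable `x`, if present
+```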
""" function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() + model::Model, chain, keytype::Type{T}=String ) where {T} - return pointwise_logdensities(model, chain, T, context, Val(:prior)) + return pointwise_logdensities(model, chain, T, Val(:prior)) end -function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - return pointwise_logdensities(model, varinfo, context, Val(:prior)) +function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/src/sampler.jl b/src/sampler.jl index 49d910fec..589c56bd3 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,12 +58,12 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - model(rng, vi, sampler) + DynamicPPL.sample!!(rng, model, vi, sampler) return vi, nothing end """ - default_varinfo(rng, model, sampler[, context]) + default_varinfo(rng, model, sampler) Return a default varinfo object for the given `model` and `sampler`. @@ -71,22 +71,13 @@ Return a default varinfo object for the given `model` and `sampler`. - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. - `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. -- `context::AbstractContext`: Context in which the model is evaluated. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - return default_varinfo(rng, model, sampler, DefaultContext()) -end -function default_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, -) init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler) end function AbstractMCMC.sample( @@ -119,7 +110,7 @@ function AbstractMCMC.step( # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi, DefaultContext())) + vi = last(evaluate!!(model, vi)) end return initialstep(rng, model, spl, vi; initial_params, kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 42fcedfb8..bf83235cb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -36,13 +36,10 @@ julia> m = demo(); julia> rng = StableRNG(42); -julia> ### Sampling ### - ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); - julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); + _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -60,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi + _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. 
- _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); + _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -94,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); +julia> _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx); +julia> _, vi = DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx); + _, vi = DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -128,7 +125,7 @@ julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) Positive probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: @@ -136,7 +133,7 @@ julia> # While if we forget to indicate that it's transformed: SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) No probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) -Inf ``` @@ -228,15 +225,25 @@ function SimpleVarInfo(; kwargs...) end # Constructor from `Model`. -function SimpleVarInfo( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... -) - return SimpleVarInfo{LogProbType}(model, args...) +function SimpleVarInfo{T}( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) where {T<:Real} + new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return last(evaluate!!(new_model, SimpleVarInfo{T}())) end function SimpleVarInfo{T}( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... 
+ model::Model, sampler::AbstractSampler=SampleFromPrior() ) where {T<:Real} - return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) + return SimpleVarInfo{T}(Random.default_rng(), model, sampler) +end +# Constructors without type param +function SimpleVarInfo( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) + return SimpleVarInfo{LogProbType}(rng, model, sampler) +end +function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) end # Constructor from `VarInfo`. @@ -252,12 +259,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict()) - return last(evaluate!!(model, varinfo, SamplingContext())) + return last(sample!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate!!(model, varinfo, SamplingContext())) + return last(sample!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index bd08b427e..67c3a8c18 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -207,41 +207,41 @@ ERROR: LoadError: cannot automatically prefix with no left-hand side resolved at runtime. """ macro submodel(prefix_expr, expr) - return submodel(prefix_expr, expr, esc(:__context__)) + return submodel(prefix_expr, expr, esc(:__model__)) end # Automatic prefixing. -function prefix_submodel_context(prefix::Bool, left::Symbol, ctx) - return prefix ? prefix_submodel_context(left, ctx) : ctx +function prefix_submodel_context(prefix::Bool, left::Symbol, model) + return prefix ? prefix_submodel_context(left, model) : :($model.context) end -function prefix_submodel_context(prefix::Bool, left::Expr, ctx) - return prefix ? prefix_submodel_context(varname(left), ctx) : ctx +function prefix_submodel_context(prefix::Bool, left::Expr, model) + return prefix ? prefix_submodel_context(varname(left), model) : :($model.context) end # Manual prefixing. -prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) -function prefix_submodel_context(prefix, ctx) +prefix_submodel_context(prefix, left, model) = prefix_submodel_context(prefix, model) +function prefix_submodel_context(prefix, model) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $model.context)) end -function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) +function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, model) # E.g. `prefix="asd"`. - return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $model.context)) end -function prefix_submodel_context(prefix::Bool, ctx) +function prefix_submodel_context(prefix::Bool, model) if prefix error("cannot automatically prefix with no left-hand side") end - return ctx + return :($model.context) end const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax." 
-function submodel(prefix_expr, expr, ctx=esc(:__context__)) +function submodel(prefix_expr, expr, model=esc(:__model__)) prefix_left, prefix = getargs_assignment(prefix_expr) if prefix_left !== :prefix error("$(prefix_left) is not a valid kwarg") @@ -257,7 +257,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) # `prefix=...` => use it. args_assign = getargs_assignment(expr) return if args_assign === nothing - ctx = prefix_submodel_context(prefix, ctx) + ctx = prefix_submodel_context(prefix, model) quote # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) @@ -271,7 +271,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) L, R = args_assign # Now that we have `L` and `R`, we can prefix automagically. try - ctx = prefix_submodel_context(prefix, L, ctx) + ctx = prefix_submodel_context(prefix, L, model) catch e error( "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 0c267c1c5..5285391b1 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -60,8 +60,6 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} model::Model "The VarInfo that was used" varinfo::AbstractVarInfo - "The evaluation context that was used" - context::AbstractContext "The values at which the model was evaluated" params::Vector{Tparams} "The AD backend that was tested" @@ -92,7 +90,6 @@ end grad_atol=1e-6, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - context::AbstractContext=DefaultContext(), reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, @@ -146,13 +143,7 @@ Everything else is optional, and can be categorised into several groups: prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. -3. _How to specify the evaluation context._ - - A `DynamicPPL.AbstractContext` can be passed as the `context` keyword - argument to control the evaluation context. This defaults to - `DefaultContext()`. - -4. _How to specify the results to compare against._ (Only if `test=true`.) +3. _How to specify the results to compare against._ (Only if `test=true`.) Once logp and its gradient has been calculated with the specified `adtype`, it must be tested for correctness. @@ -167,12 +158,12 @@ Everything else is optional, and can be categorised into several groups: The default reference backend is ForwardDiff. If none of these parameters are specified, ForwardDiff will be used to calculate the ground truth. -5. _How to specify the tolerances._ (Only if `test=true`.) +4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for the value and gradient can be set using `value_atol` and `grad_atol`. These default to 1e-6. -6. _Whether to output extra logging information._ +5. _Whether to output extra logging information._ By default, this function prints messages when it runs. To silence it, set `verbose=false`. 
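Putting the options above together, a typical invocation after this change might look as follows (a sketch: it assumes the usual `run_ad(model, adtype; ...)` calling convention and a toy model; note the absence of any `context` keyword):

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
import ForwardDiff  # the backend package itself must be loaded

@model demo() = x ~ Normal()

# Test the AD backend against the ForwardDiff reference, silently.
result = DynamicPPL.TestUtils.AD.run_ad(
    demo(), AutoForwardDiff();
    test=true, value_atol=1e-6, grad_atol=1e-6, verbose=false,
)
```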
@@ -195,7 +186,6 @@ function run_ad( grad_atol::AbstractFloat=1e-6, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - context::AbstractContext=DefaultContext(), reference_adtype::AbstractADType=REFERENCE_ADTYPE, expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, @@ -207,7 +197,7 @@ function run_ad( verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo, context; adtype=adtype) + ldf = LogDensityFunction(model, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) grad = collect(grad) @@ -216,7 +206,7 @@ function run_ad( if test # Calculate ground truth to compare against value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype) + ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) logdensity_and_gradient(ldf_reference, params) else expected_value_and_grad @@ -245,7 +235,6 @@ function run_ad( return ADResult( model, varinfo, - context, params, adtype, value_atol, diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index ce79f2302..1a2ec6ebe 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,8 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) - ) + new_model = contextualize(model, SamplingContext(model.context)) + return collect(keys(last(DynamicPPL.evaluate!!(new_model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 07a308c7a..542fc17fc 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -48,7 +48,7 @@ function setup_varinfos( )) do vi # Set them all to the same values and evaluate logp. vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + last(DynamicPPL.evaluate!!(model, vi)) end if include_threadsafe diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cc07d70bb..51c57651d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -116,14 +116,17 @@ end # consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates # to define `getacc(vi)`. 
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + model = contextualize( + model, setleafcontext(model.context, DynamicTransformationContext{false}()) + ) + return settrans!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), + model = contextualize( + model, setleafcontext(model.context, DynamicTransformationContext{true}()) ) + return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) end function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) diff --git a/src/transforming.jl b/src/transforming.jl index ddd1ab59f..e3da0ff29 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -51,16 +51,17 @@ function _transform!!( vi::AbstractVarInfo, model::Model, ) - # To transform using DynamicTransformationContext, we evaluate the model, but we do not - # need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of - # the transformation). + # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: + model = contextualize(model, setleafcontext(model.context, ctx)) + # but we do not need to use any accumulators other than LogPriorAccumulator + # (which is affected by the Jacobian of the transformation). accs = getaccs(vi) has_logprior = haskey(accs, Val(:LogPrior)) if has_logprior old_logprior = getacc(accs, Val(:LogPrior)) vi = setaccs!!(vi, (old_logprior,)) end - vi = settrans!!(last(evaluate!!(model, vi, ctx)), t) + vi = settrans!!(last(evaluate!!(model, vi)), t) # Restore the accumulators. if has_logprior new_logprior = getacc(vi, Val(:LogPrior)) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4d6225c10..4922ddbb0 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -52,7 +52,7 @@ end accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc """ - values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) + values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) Get the values of `varinfo` as they would be seen in the model. @@ -69,8 +69,6 @@ space at the cost of additional model evaluations. - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: evaluation context to use in the extraction. Defaults - to `DynamicPPL.DefaultContext()`. # Examples @@ -124,14 +122,8 @@ julia> # Approach 2: Extract realizations using `values_as_in_model`. 
true ``` """ -function values_as_in_model( - model::Model, - include_colon_eq::Bool, - varinfo::AbstractVarInfo, - context::AbstractContext=DefaultContext(), -) - accs = getaccs(varinfo) +function values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),)) - varinfo = last(evaluate!!(model, varinfo, context)) + varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(:ValuesAsInModel)).values end diff --git a/src/varinfo.jl b/src/varinfo.jl index 20986d1a4..8cd14b134 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -106,10 +106,10 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler, context]) + VarInfo([rng, ]model[, sampler]) Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`, and `context`. +the given `rng`, `sampler`. !!! warning @@ -122,28 +122,12 @@ the given `rng`, `sampler`, and `context`. instead. """ function VarInfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_varinfo(rng, model, sampler, context) -end -function VarInfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return VarInfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(rng, model, sampler) end -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return VarInfo(rng, model, SampleFromPrior(), context) -end -function VarInfo(model::Model, context::AbstractContext) - # No sampler, no rng - return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return VarInfo(Random.default_rng(), model, sampler) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -200,7 +184,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_varinfo([rng, ]model[, sampler]) Return a VarInfo object for the given `model` and `context`, which has just a single `Metadata` as its metadata field. @@ -209,33 +193,16 @@ single `Metadata` as its metadata field. - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
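With the `context` argument gone, only the RNG and the sampler remain configurable. A minimal sketch of the surviving call patterns, using a toy model:

```julia
using DynamicPPL, Distributions, Random

@model demo() = x ~ Normal()
model = demo()

vi = DynamicPPL.untyped_varinfo(model)                        # default RNG and SampleFromPrior()
vi = DynamicPPL.untyped_varinfo(Random.default_rng(), model)  # explicit RNG
```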
""" function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) varinfo = VarInfo(Metadata()) - context = SamplingContext(rng, sampler, context) - return last(evaluate!!(model, varinfo, context)) -end -function untyped_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_varinfo(Random.default_rng(), model, sampler, context) -end -function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return untyped_varinfo(rng, model, SampleFromPrior(), context) + new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return last(evaluate!!(new_model, varinfo)) end -function untyped_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_varinfo(model, SampleFromPrior(), context) +function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_varinfo(Random.default_rng(), model, sampler) end """ @@ -298,87 +265,51 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler, context, metadata]) + typed_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function typed_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return typed_varinfo(Random.default_rng(), model, sampler, context) -end -function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return typed_varinfo(rng, model, SampleFromPrior(), context) + return typed_varinfo(untyped_varinfo(rng, model, sampler)) end -function typed_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_varinfo(model, SampleFromPrior(), context) +function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_varinfo(Random.default_rng(), model, sampler) end """ - untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `VarNamedVector` as its metadata field. +Return a VarInfo object for the given `model`, which has just a single +`VarNamedVector` as its metadata field. 
# Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function untyped_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) -end -function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No sampler - return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) end -function untyped_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_vector_varinfo(model, SampleFromPrior(), context) +function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_vector_varinfo(Random.default_rng(), model, sampler) end """ - typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) + typed_vector_varinfo([rng, ]model[, sampler]) Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -387,7 +318,6 @@ NamedTuple of `VarNamedVector`s as its metadata field. - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
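The vector-backed constructors follow the same pattern after this change, for instance (a sketch; `model` is any instantiated DynamicPPL model, as in the example further up):

```julia
vi = DynamicPPL.untyped_vector_varinfo(model)
vi = DynamicPPL.typed_vector_varinfo(Random.default_rng(), model, SampleFromPrior())
```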
""" function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -399,30 +329,12 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) -end -function typed_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return typed_vector_varinfo(Random.default_rng(), model, sampler, context) -end -function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No sampler - return typed_vector_varinfo(rng, model, SampleFromPrior(), context) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) end -function typed_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_vector_varinfo(model, SampleFromPrior(), context) +function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_vector_varinfo(Random.default_rng(), model, sampler) end """ diff --git a/test/ad.jl b/test/ad.jl index c34624f5b..0947c017a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -110,9 +110,8 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = LogDensityFunction( - model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true)) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end diff --git a/test/compiler.jl b/test/compiler.jl index 2e76de27f..de22d1b67 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -185,10 +185,7 @@ module Issue537 end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng global lp = getlogjoint(__varinfo__) return x end @@ -196,18 +193,18 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - @test model_ === model - @test context_ isa SamplingContext - @test rng_ isa Random.AbstractRNG + # During the model evaluation, its context is wrapped in a + # SamplingContext, so `model_` is not going to be equal to `model`. + # We can still check equality of `f` though. + @test model_.f === model.f + @test model_.context isa SamplingContext + @test model_.context.rng isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng global lp = getlogjoint(__varinfo__) return x end false @@ -601,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. 
@model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) + retval_and_vi = DynamicPPL.sample!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.sample!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -623,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.sample!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.sample!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/context_implementations.jl b/test/context_implementations.jl index ac6321d69..e16b2dc96 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -5,12 +5,12 @@ μ ~ MvNormal(zeros(2), 4 * I) z = Vector{Int}(undef, length(x)) z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x)))) - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(μ[z[i]], 0.1) end end - test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext()) + test([1, 1, -1])(VarInfo()) end @testset "dot tilde with varying sizes" begin diff --git a/test/contexts.jl b/test/contexts.jl index 1dd6a2280..597ab736c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -184,9 +184,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) + sampling_model = contextualize(model, context) # Sample with the context. 
varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(model, varinfo, context) + DynamicPPL.evaluate!!(sampling_model, varinfo) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d2269e089..8279ac51a 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,7 +1,7 @@ @testset "check_model" begin @testset "context interface" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext(model) + context = DynamicPPL.DebugUtils.DebugContext() DynamicPPL.TestUtils.test_context(context, model) end end @@ -35,9 +35,7 @@ buggy_model = buggy_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -81,9 +79,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -98,9 +94,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -115,9 +109,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 86329a51d..6737cf056 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -62,6 +62,7 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation and sampling @@ -71,7 +72,7 @@ JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, DynamicPPL.SamplingContext() + sampling_model, varinfo ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. @@ -85,7 +86,7 @@ ) JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi, DynamicPPL.SamplingContext() + sampling_model, typed_vi ) JET.test_call(f_sample, argtypes_sample) end diff --git a/test/linking.jl b/test/linking.jl index 4f1707263..b0c2dcb5c 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -78,7 +78,7 @@ end vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),)) @testset "$(short_varinfo_name(vi))" for vi in vis # Evaluate once to ensure we have `logp` value. 
- vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) vi_linked = if mutable DynamicPPL.link!!(deepcopy(vi), model) else diff --git a/test/model.jl b/test/model.jl index ea260a68c..574d276d6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -162,12 +162,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.sample!!(Random.default_rng(), model, vi, sampler) vals = vi[:] Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.sample!!(Random.default_rng(), model, vi, sampler) @test vi[:] == vals end end @@ -223,7 +223,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) + evaluate_retval = DynamicPPL.evaluate!!(model, vi) @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo # Should not return `AbstractVarInfo` when we call the model. @@ -332,11 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last( - DynamicPPL.evaluate!!( - model, SimpleVarInfo(OrderedDict()), SamplingContext() - ), - ) + vi = last(DynamicPPL.sample!!(model, SimpleVarInfo(OrderedDict()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -397,7 +393,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] - context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -407,13 +402,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo)) true end varinfo_linked = DynamicPPL.link(varinfo, model) @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) true end end @@ -492,7 +487,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) + DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) end @@ -596,7 +591,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000] + chain = [last(DynamicPPL.sample!!(m_lin_reg, VarInfo())) for _ in 1:10000] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git 
a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 6f2f39a64..2c9986dcc 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -98,7 +98,7 @@ for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) # `link!!` vi_linked = link!!(deepcopy(vi), model) @@ -158,7 +158,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + _, svi_new = DynamicPPL.sample!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -226,9 +226,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_nt = last(DynamicPPL.sample!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + svi_vnv = last(DynamicPPL.sample!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -236,7 +236,7 @@ for vn in keys(svi) svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + retval, svi = DynamicPPL.evaluate!!(model, svi) @test retval.m == svi[@varname(m)] # `m` is unconstrained @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` @@ -273,7 +273,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) + vi_result = last(DynamicPPL.sample!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. @@ -281,9 +281,7 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!( - model, deepcopy(vi_linked), DefaultContext() - ) + retval, vi_linked_result = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original diff --git a/test/threadsafe.jl b/test/threadsafe.jl index c673c8b36..24a738a78 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -52,9 +52,10 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wthreads(x) vi = VarInfo() - wthreads(x)(vi) + model(vi) lp_w_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo @@ -64,23 +65,19 @@ println("With `@threads`:") println(" default:") - @time wthreads(x)(vi) + @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. 
- DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) @test getlogjoint(vi) ≈ lp_w_threads + # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes + @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -89,9 +86,10 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wothreads(x) vi = VarInfo() - wothreads(x)(vi) + model(vi) lp_wo_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo @@ -101,24 +99,18 @@ println("Without `@threads`:") println(" default:") - @time wothreads(x)(vi) + @time model(vi) @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo + @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 053fd3203..d0e63c17d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -491,7 +491,7 @@ end # Check that instantiating the model does not perform linking vi = VarInfo() meta = vi.metadata - model(vi, SampleFromUniform()) + model(vi) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -565,7 +565,7 @@ end vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -574,7 +574,7 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -583,7 +583,7 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. 
- vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -592,7 +592,7 @@ end ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -600,7 +600,7 @@ end ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -608,7 +608,7 @@ end ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -690,7 +690,7 @@ end end # Evaluate the model once to update the logp of the varinfo. - varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) + varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) varinfo_linked = if mutating DynamicPPL.link!!(deepcopy(varinfo), model) @@ -993,9 +993,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -1014,9 +1012,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index f21d458a8..33fb1a162 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -603,18 +603,14 @@ end DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) # Is evaluation correct? - varinfo_eval = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) - ) + varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo))) # Log density should be the same. @test getlogjoint(varinfo_eval) ≈ logp_true # Values should be the same. DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) - ) + varinfo_sample = last(DynamicPPL.sample!!(model, deepcopy(varinfo))) # Log density should be different. 
@test getlogjoint(varinfo_sample) != getlogjoint(varinfo)
# Values should be different.

From 428b628a22da8391bcf7fd1fa1368d9b0dd57b17 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Sat, 14 Jun 2025 01:15:56 +0100
Subject: [PATCH 03/10] Fix some docstrings

---
 src/model.jl | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index fd275dc17..403899760 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -856,17 +856,20 @@ end

"""
    evaluate!!(model::Model, varinfo)
-    evaluate!!(model::Model, varinfo, context)

-Evaluate the `model` with the given `varinfo`. If an extra context stack is
-provided, the model's context is inserted into that context stack. See
-`combine_model_and_external_contexts`.
+Evaluate the `model` with the given `varinfo`.

If multiple threads are available, the varinfo provided will be wrapped in a
`ThreadSafeVarInfo` before evaluation.

Returns a tuple of the model's return value, plus the updated `varinfo`
(unwrapped if necessary).
+
+    evaluate!!(model::Model, varinfo, context)
+
+When an extra context stack is provided, the model's context is inserted into
+that context stack. See `combine_model_and_external_contexts`. This method is
+deprecated.
"""
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
    return if use_threadsafe_eval(model.context, varinfo)
@@ -924,13 +927,16 @@ end

"""
    _evaluate!!(model::Model, varinfo)
-    _evaluate!!(model::Model, varinfo, context)

-Evaluate the `model` with the given `varinfo`. If an additional `context` is provided,
-the model's context is combined with that context.
+Evaluate the `model` with the given `varinfo`.

This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does
not reset the log probability of the `varinfo` before running.
+
+    _evaluate!!(model::Model, varinfo, context)
+
+If an additional `context` is provided, the model's context is combined with
+that context before evaluation.
"""
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
    args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
@@ -975,7 +981,7 @@ function combine_model_and_external_contexts(
end

"""
-    make_evaluate_args_and_kwargs(model, varinfo, context)
+    make_evaluate_args_and_kwargs(model, varinfo)

Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`.
"""

From 85805655a547387f7b9835bfb86180977d2f594d Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Sat, 14 Jun 2025 01:20:56 +0100
Subject: [PATCH 04/10] fix ForwardDiffExt (look, multiple dispatch bad...)
--- ext/DynamicPPLForwardDiffExt.jl | 1 - test/ext/DynamicPPLForwardDiffExt.jl | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 6bd7a5d94..7ea51918f 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -11,7 +11,6 @@ function DynamicPPL.tweak_adtype( ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo, - ::DynamicPPL.AbstractContext, ) where {chunk_size} params = vi[:] diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 73a0510e9..44db66296 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,17 +14,16 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) - context = DefaultContext() @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end From fe3a8d588ffd73996cacb4411b80753919d0cecd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 14 Jun 2025 01:37:36 +0100 Subject: [PATCH 05/10] Changelog --- HISTORY.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 7ab9ee1dc..80a0cc981 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -18,6 +18,41 @@ This release overhauls how VarInfo objects track variables such as the log joint - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well. +### Evaluation contexts + +Historically, evaluating a DynamicPPL model has required three arguments: a model, some kind of VarInfo, and a context. +It's less known, though, that since DynamicPPL 0.14.0 the _model_ itself actually contains a context as well. +This version therefore excises the context argument, and instead uses `model.context` as the evaluation context. + +The upshot of this is that many functions that previously took a context argument now no longer do. +There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value). + +`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. +If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost. 
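+For example (a sketch; `MyWrapperContext` stands in for whatever external parent context you were previously passing):
+
+```julia
+# (`MyWrapperContext` is a hypothetical user-defined parent context.)
+# Old, now-deprecated call:
+retval, vi = DynamicPPL.evaluate!!(model, varinfo, MyWrapperContext(DefaultContext()))
+
+# New equivalent: wrap the model's own context so its information is kept.
+new_model = contextualize(model, MyWrapperContext(model.context))
+retval, vi = DynamicPPL.evaluate!!(new_model, varinfo)
+```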
+If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely.
+
+To aid with this process, `contextualize` is now exported from DynamicPPL.
+
+The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`.
+Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object.
+Thus, this release also introduces the unexported function `sample!!`.
+Essentially, `sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`.
+
+There are many methods that no longer take a context argument, and listing them all would be too much.
+However, here are the more user-facing ones:
+
+  - `LogDensityFunction` no longer has a context field (or type parameter)
+  - `DynamicPPL.TestUtils.AD.run_ad` no longer uses a context (and the returned `ADResult` object no longer has a context field)
+  - `VarInfo(rng, model, sampler)` and other VarInfo constructors / functions that made VarInfos (e.g. `typed_varinfo`) from a model
+  - `(::Model)(args...)`: specifically, this now only takes `rng` and `varinfo` arguments (with both being optional)
+  - If you are using the `__context__` special variable inside a model, you will now have to use `__model__.context` instead
+
+And a couple of more internal changes:
+
+  - `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments
+  - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `sample!!` instead, or construct your own `SamplingContext`)
+  - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument

## 0.36.12

Removed several unexported functions.

From a7b300995528ecd751bc303509b873c805adf55a Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Wed, 18 Jun 2025 21:49:02 +0100
Subject: [PATCH 06/10] fix a test

---
 test/varinfo.jl | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/test/varinfo.jl b/test/varinfo.jl
index d0e63c17d..975ec9498 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -488,10 +488,17 @@ end
    end

    model = gdemo([1.0, 1.5], [2.0, 2.5])

-    # Check that instantiating the model does not perform linking
+    # Check that instantiating the model using SampleFromUniform does not
+    # perform linking
+    # Note (penelopeysm): The reason for using SampleFromUniform (SFU)
+    # specifically in this test is that SFU samples from the linked
+    # distribution, i.e. in unconstrained space. However, it does this not
+    # by linking the varinfo but by transforming the distributions on the
+    # fly. That's why it's worth specifically checking that it can do this
+    # without having to change the VarInfo object.
vi = VarInfo() meta = vi.metadata - model(vi) + _, vi = DynamicPPL.sample!!(model, vi, SampleFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly From ba3446146656eb13d0df06ecfc7763727bd86165 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 15:58:15 +0100 Subject: [PATCH 07/10] Fix docstrings --- src/logdensityfunction.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e489b46ba..e7565d137 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -28,9 +28,9 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with its -the type of varinfo to be used. These must be known in order to calculate the -log density (using [`DynamicPPL.evaluate!!`](@ref)). +At its most basic level, a LogDensityFunction wraps the model together with the +type of varinfo to be used. These must be known in order to calculate the log +density (using [`DynamicPPL.evaluate!!`](@ref)). If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -139,7 +139,7 @@ end adtype::Union{Nothing,ADTypes.AbstractADType} ) -Create a new LogDensityFunction using the model, and varinfo from the given +Create a new LogDensityFunction using the model and varinfo from the given `ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass `nothing` as the second argument. """ From b92280d3699d40f0e6c61991b476f07bc976d95f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 15:59:05 +0100 Subject: [PATCH 08/10] use `sample!!` --- src/extract_priors.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index ac25bbbf8..0d8a190de 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -116,8 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model) # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you # can't push new variables without knowing the num_produce. Remove this when possible. varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) - new_model = contextualize(model, SamplingContext(rng, model.context)) - varinfo = last(evaluate!!(new_model, varinfo)) + varinfo = last(sample!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end From d8019e1c6b61fd9eac53fa26fc057593d98be34b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 16:02:15 +0100 Subject: [PATCH 09/10] Fix a couple more cases --- src/test_utils/model_interface.jl | 3 +-- src/varinfo.jl | 12 +++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 1a2ec6ebe..8d7d49d05 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,8 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. 
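For instance, the default implementation behaves roughly like this (a sketch with a toy model; because the default is backed by a `Dict`, the ordering of the returned `VarName`s is not guaranteed):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal()
end

# Collects the VarNames encountered during one sampled evaluation.
DynamicPPL.TestUtils.varnames(demo())  # e.g. [@varname(x), @varname(y)]
```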
""" function varnames(model::Model) - new_model = contextualize(model, SamplingContext(model.context)) - return collect(keys(last(DynamicPPL.evaluate!!(new_model, SimpleVarInfo(Dict()))))) + return collect(keys(last(DynamicPPL.sample!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 8cd14b134..4e6e61f66 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -186,8 +186,8 @@ end """ untyped_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `Metadata` as its metadata field. +Construct a VarInfo object for the given `model`, which has just a single +`Metadata` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation @@ -197,9 +197,7 @@ single `Metadata` as its metadata field. function untyped_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - varinfo = VarInfo(Metadata()) - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return last(evaluate!!(new_model, varinfo)) + return last(sample!!(rng, model, VarInfo(Metadata()), sampler)) end function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) return untyped_varinfo(Random.default_rng(), model, sampler) @@ -311,8 +309,8 @@ end """ typed_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a -NamedTuple of `VarNamedVector`s as its metadata field. +Return a VarInfo object for the given `model`, which has a NamedTuple of +`VarNamedVector`s as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation From fa90d1c78b45f09e6e0a11d1bc25224998b520ac Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 16:04:54 +0100 Subject: [PATCH 10/10] Globally rename `sample!!` -> `evaluate_and_sample!!`, add changelog warning --- HISTORY.md | 7 ++++--- docs/src/api.md | 2 +- ext/DynamicPPLMCMCChainsExt.jl | 2 +- src/extract_priors.jl | 2 +- src/model.jl | 12 ++++++------ src/sampler.jl | 2 +- src/simple_varinfo.jl | 20 ++++++++++---------- src/test_utils/model_interface.jl | 4 +++- src/varinfo.jl | 2 +- test/compiler.jl | 8 ++++---- test/model.jl | 11 +++++++---- test/simple_varinfo.jl | 8 ++++---- test/varinfo.jl | 18 +++++++++--------- test/varnamedvector.jl | 4 +++- 14 files changed, 55 insertions(+), 47 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 80a0cc981..9edac441f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -35,8 +35,9 @@ To aid with this process, `contextualize` is now exported from DynamicPPL. The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`. Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object. -Thus, this release also introduces the unexported function `sample!!`. -Essentially, `sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`. +Thus, this release also introduces the **unexported** function `evaluate_and_sample!!`. +Essentially, `evaluate_and_sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`. +**Do note that this is an internal method**, and its name or semantics are liable to change in the future without warning. 
There are many methods that no longer take a context argument, and listing them all would be too much. However, here are the more user-facing ones: @@ -50,7 +51,7 @@ However, here are the more user-facing ones: And a couple of more internal changes: - `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments - - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `sample!!` instead, or construct your own `SamplingContext`) + - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`) - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument ## 0.36.12 diff --git a/docs/src/api.md b/docs/src/api.md index b867e2e64..32b3d80a6 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -455,7 +455,7 @@ By default, it does not perform any actual sampling: it only evaluates the model To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: ```@docs -DynamicPPL.sample!! +DynamicPPL.evaluate_and_sample!! ``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 5e1c75aa5..a29696720 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -115,7 +115,7 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.sample!!(rng, model, varinfo)) + varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 0d8a190de..bd6bdb2f2 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -116,7 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model) # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you # can't push new variables without knowing the num_produce. Remove this when possible. varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) - varinfo = last(sample!!(rng, model, varinfo)) + varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index 403899760..f46137ed1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -815,7 +815,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. 
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(sample!!(rng, model, varinfo)) + return first(evaluate_and_sample!!(rng, model, varinfo)) end """ @@ -829,7 +829,7 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) + evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) Evaluate the `model` with the given `varinfo`, but perform sampling during the evaluation using the given `sampler` by wrapping the model's context in a @@ -839,7 +839,7 @@ If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function sample!!( +function evaluate_and_sample!!( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo, @@ -848,10 +848,10 @@ function sample!!( sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) return evaluate!!(sampling_model, varinfo) end -function sample!!( +function evaluate_and_sample!!( model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() ) - return sample!!(Random.default_rng(), model, varinfo, sampler) + return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end """ @@ -1038,7 +1038,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) + x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) return values_as(x, T) end diff --git a/src/sampler.jl b/src/sampler.jl index 589c56bd3..673b5128f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,7 +58,7 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - DynamicPPL.sample!!(rng, model, vi, sampler) + DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) return vi, nothing end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index bf83235cb..ea371c7da 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); julia> # (✓) Sort of fast, but only possible at runtime. 
vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -259,12 +259,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict()) - return last(sample!!(model, varinfo)) + return last(evaluate_and_sample!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(sample!!(model, varinfo)) + return last(evaluate_and_sample!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 8d7d49d05..93aed074c 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,7 +92,9 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect(keys(last(DynamicPPL.sample!!(model, SimpleVarInfo(Dict()))))) + return collect( + keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) + ) end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 4e6e61f66..b3380e7f9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -197,7 +197,7 @@ Construct a VarInfo object for the given `model`, which has just a single function untyped_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - return last(sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) end function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) return untyped_varinfo(Random.default_rng(), model, sampler) diff --git a/test/compiler.jl b/test/compiler.jl index de22d1b67..2d1342fea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -598,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. 
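The test below exercises the return-value contract of `evaluate_and_sample!!`; as a standalone, hedged sketch of the same behaviour (default RNG and sampler assumed):

```julia
using DynamicPPL

@model empty_model() = return x = 1

# `evaluate_and_sample!!` always returns a (return value, varinfo) tuple,
# even for a model body with no tilde statements.
retval, vi = DynamicPPL.evaluate_and_sample!!(empty_model(), VarInfo())
retval  # 1
```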
@model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/model.jl b/test/model.jl index 574d276d6..daa3cc743 100644 --- a/test/model.jl +++ b/test/model.jl @@ -162,12 +162,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 Random.seed!(100 + i) vi = VarInfo() - DynamicPPL.sample!!(Random.default_rng(), model, vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) vals = vi[:] Random.seed!(100 + i) vi = VarInfo() - DynamicPPL.sample!!(Random.default_rng(), model, vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) @test vi[:] == vals end end @@ -332,7 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -591,7 +591,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [last(DynamicPPL.sample!!(m_lin_reg, VarInfo())) for _ in 1:10000] + chain = [ + last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for + _ in 1:10000 + ] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 2c9986dcc..e300c651e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -158,7 +158,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.sample!!(model, svi) + _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -226,9 +226,9 @@ # Initialize. 
svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -273,7 +273,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/varinfo.jl b/test/varinfo.jl index 975ec9498..d788e6215 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -498,7 +498,7 @@ end # without having to change the VarInfo object. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -572,7 +572,7 @@ end vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.sample!!(model, vi)) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -581,7 +581,7 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.sample!!(model, vi)) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -590,7 +590,7 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.sample!!(model, vi)) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -599,7 +599,7 @@ end ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) # Sample in unconstrained space. - vi = last(DynamicPPL.sample!!(model, vi)) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -607,7 +607,7 @@ end ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) # Sample in unconstrained space. - vi = last(DynamicPPL.sample!!(model, vi)) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -615,7 +615,7 @@ end ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) # Sample in unconstrained space. 
- vi = last(DynamicPPL.sample!!(model, vi)) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -1000,7 +1000,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo(2) - varinfo2 = last(DynamicPPL.sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -1019,7 +1019,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 33fb1a162..57a8175d4 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -610,7 +610,9 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last(DynamicPPL.sample!!(model, deepcopy(varinfo))) + varinfo_sample = last( + DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) + ) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different.
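Finally, the linked-space sampling pattern that recurs throughout these test hunks, collected into one self-contained, hedged sketch (the constrained model body is an assumption; the patch itself only shows the call pattern):

```julia
using DynamicPPL, Distributions, Random

@model function demo_constrained()
    x ~ truncated(Normal(), 0, Inf)  # assumed body; any constrained prior works
    return x
end

model = demo_constrained()

# Mark the varinfo as transformed, then sample: the draw is stored in
# unconstrained (linked) space, so the reported value may be negative.
vi = DynamicPPL.settrans!!(SimpleVarInfo(), true)
vi = last(DynamicPPL.evaluate_and_sample!!(model, vi))
vi[@varname(x)]
```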