diff --git a/HISTORY.md b/HISTORY.md index 450365f1d..fcd005579 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,75 @@ # DynamicPPL Changelog +## 0.37.0 + +**Breaking changes** + +### Submodel macro + +The `@submodel` macro is fully removed; please use `to_submodel` instead. + +### `DynamicPPL.TestUtils.AD.run_ad` + +The three keyword arguments, `test`, `reference_backend`, and `expected_value_and_grad` have been merged into a single `test` keyword argument. +Please see the API documentation for more details. +(The old `test=true` and `test=false` values are still valid, and you only need to adjust the invocation if you were explicitly passing the `reference_backend` or `expected_value_and_grad` arguments.) + +There is now also an `rng` keyword argument to help seed parameter generation. + +Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. +Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`. + +### Accumulators + +This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: + + - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. 
`varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`. + - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. + - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. + - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. + - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`. + - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. + - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. + - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. + - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well. 
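+As a rough sketch of the new log-probability API (this is illustrative only, assuming a trivial model; see the API docs for the authoritative signatures):
+
+```julia
+using DynamicPPL, Distributions
+
+@model function demo(x)
+    m ~ Normal()   # contributes to the log prior
+    x ~ Normal(m)  # contributes to the log likelihood
+end
+
+vi = VarInfo(demo(1.0))
+
+getlogprior(vi)       # log prior only
+getloglikelihood(vi)  # log likelihood only
+getlogjoint(vi)       # their sum -- what `getlogp` used to return
+getlogp(vi)           # NamedTuple with keys `logprior` and `loglikelihood`
+
+# To evaluate while tracking only the log prior (saving some computation),
+# keep just the one accumulator:
+vi_prior_only = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),))
+```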
+ +### Evaluation contexts + +Historically, evaluating a DynamicPPL model has required three arguments: a model, some kind of VarInfo, and a context. +It's less known, though, that since DynamicPPL 0.14.0 the _model_ itself actually contains a context as well. +This version therefore excises the context argument, and instead uses `model.context` as the evaluation context. + +The upshot of this is that many functions that previously took a context argument now no longer do. +There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value). + +`evaluate!!(model, varinfo, ext_context)` is removed, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. +If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost. +If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely. + +**To aid with this process, `contextualize` is now exported from DynamicPPL.** + +The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`. +Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object. +Thus, this release also introduces the **unexported** function `evaluate_and_sample!!`. +Essentially, `evaluate_and_sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`. +**Do note that this is an internal method**, and its name or semantics are liable to change in the future without warning. + +There are many methods that no longer take a context argument, and listing them all would be too much. 
+However, here are the more user-facing ones: + + - `LogDensityFunction` no longer has a context field (or type parameter) + - `DynamicPPL.TestUtils.AD.run_ad` no longer uses a context (and the returned `ADResult` object no longer has a context field) + - `VarInfo(rng, model, sampler)` and other VarInfo constructors / functions that made VarInfos (e.g. `typed_varinfo`) from a model + - `(::Model)(args...)`: specifically, this now only takes `rng` and `varinfo` arguments (with both being optional) + - If you are using the `__context__` special variable inside a model, you will now have to use `__model__.context` instead + +And a couple of more internal changes: + + - Just like `evaluate!!`, the other functions `_evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` now no longer accept context arguments + - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`) + - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument + - The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched. + ## 0.36.15 Bumped minimum Julia version to 1.10.8 to avoid potential crashes with `Core.Compiler.widenconst` (which Mooncake uses). 
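+The context migration described in the 0.37.0 notes above can be sketched as follows (illustrative only; note that `evaluate_and_sample!!` is internal and its name or semantics may change):
+
+```julia
+using DynamicPPL, Distributions, Random
+
+@model function demo()
+    m ~ Normal()
+end
+model = demo()
+vi = VarInfo(model)
+
+# Pre-0.37 (no longer works):
+#   evaluate!!(model, vi, SamplingContext(Random.default_rng(), SampleFromPrior()))
+
+# From 0.37: store the context in the model itself, wrapping the existing
+# `model.context` so its information is not lost...
+sampling_model = contextualize(
+    model, SamplingContext(Random.default_rng(), SampleFromPrior(), model.context)
+)
+_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
+
+# ...or use the (unexported, internal) convenience wrapper:
+_, vi = DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi)
+```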
diff --git a/Project.toml b/Project.toml index 63c07ed1a..c23845b8c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.15" +version = "0.37.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -68,6 +69,7 @@ MCMCChains = "6, 7" MacroTools = "0.5.6" Mooncake = "0.4.95" OrderedCollections = "1" +Printf = "1.10" Random = "1.6" Requires = "1" Statistics = "1" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 2b3bfbbdd..3d14d03ff 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -15,11 +15,14 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +[sources] +DynamicPPL = {path = "../"} + [compat] ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.36" +DynamicPPL = "0.37" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 89b65d2de..b733d810c 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -1,6 +1,4 @@ using Pkg -# To ensure we benchmark the local version of DynamicPPL, dev the folder above. 
-Pkg.develop(; path=joinpath(@__DIR__, "..")) using DynamicPPLBenchmarks: Models, make_suite, model_dimension using BenchmarkTools: @benchmark, median, run @@ -100,4 +98,5 @@ PrettyTables.pretty_table( header=header, tf=PrettyTables.tf_markdown, formatters=ft_printf("%.1f", [6, 7]), + crop=:none, # Always print the whole table, even if it doesn't fit in the terminal. ) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 6f486e2f5..26ec35b65 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: end adbackend = to_backend(adbackend) - context = DynamicPPL.DefaultContext() if islinked vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend) + f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/Project.toml b/docs/Project.toml index 3f258909a..5797a8fd1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -20,7 +20,7 @@ DataStructures = "0.18" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.36" +DynamicPPL = "0.37" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" diff --git a/docs/src/api.md b/docs/src/api.md index a1adcb21c..e918a095c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,6 +36,12 @@ getargnames getmissings ``` +The context of a model can be set using [`contextualize`](@ref): + +```@docs +contextualize +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). @@ -140,27 +146,15 @@ to_submodel Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations. 
-In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref) - -```@docs -@submodel -``` - In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: ```@docs DynamicPPL.prefix ``` -Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else - -```@docs -returned(::Model) -``` - ## Utilities -It is possible to manually increase (or decrease) the accumulated log density from within a model function. +It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @addlogprob! @@ -212,6 +206,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL ```@docs DynamicPPL.TestUtils.AD.run_ad +``` + +The default test setting is to compare against ForwardDiff. +You can have more fine-grained control over how to test the AD backend using the following types: + +```@docs +DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting +DynamicPPL.TestUtils.AD.WithBackend +DynamicPPL.TestUtils.AD.WithExpectedResult +DynamicPPL.TestUtils.AD.NoTest +``` + +These are returned / thrown by the `run_ad` function: + +```@docs DynamicPPL.TestUtils.AD.ADResult DynamicPPL.TestUtils.AD.ADIncorrectException ``` @@ -329,9 +338,9 @@ The following functions were used for sequential Monte Carlo methods. ```@docs get_num_produce -set_num_produce! -increment_num_produce! -reset_num_produce! +set_num_produce!! +increment_num_produce!! +reset_num_produce!! setorder! set_retained_vns_del! ``` @@ -346,6 +355,22 @@ Base.empty! 
SimpleVarInfo ``` +### Accumulators + +The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. + +```@docs +AbstractAccumulator +``` + +DynamicPPL provides the following default accumulators. + +```@docs +LogPriorAccumulator +LogLikelihoodAccumulator +NumProduceAccumulator +``` + ### Common API #### Accumulation of log-probabilities @@ -354,6 +379,13 @@ SimpleVarInfo getlogp setlogp!! acclogp!! +getlogjoint +getlogprior +setlogprior!! +acclogprior!! +getloglikelihood +setloglikelihood!! +accloglikelihood!! resetlogp!! ``` @@ -416,21 +448,26 @@ DynamicPPL.varname_and_value_leaves ### Evaluation Contexts -Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref). +Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref). ```@docs AbstractPPL.evaluate!! ``` -The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function. +This method mutates the `varinfo` used for execution. +By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: + +```@docs +DynamicPPL.evaluate_and_sample!! +``` + +The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. 
```@docs SamplingContext DefaultContext -LikelihoodContext -PriorContext -MiniBatchContext PrefixContext ConditionContext ``` @@ -477,7 +514,3 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume ``` - -```@docs -tilde_observe -``` diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 6bd7a5d94..7ea51918f 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -11,7 +11,6 @@ function DynamicPPL.tweak_adtype( ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo, - ::DynamicPPL.AbstractContext, ) where {chunk_size} params = vi[:] diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index aa95093f2..760d17bb0 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo; - only_ddpl::Bool=true, + model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) # Let's make sure that both evaluation and sampling doesn't result in type errors. - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, context - ) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. result = if only_ddpl @@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo( end function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true + model::DynamicPPL.Model; only_ddpl::Bool=true ) + # Use SamplingContext to test type stability. 
+ sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(model, context) + varinfo = DynamicPPL.typed_varinfo(sampling_model) # Let's make sure that both evaluation and sampling doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - model, context, varinfo; only_ddpl + sampling_model, varinfo; only_ddpl ) if !issuccess @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model, context) + DynamicPPL.untyped_varinfo(sampling_model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7fcbd6a7c..a29696720 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -48,10 +48,10 @@ end Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample in `chain`, and return the resulting `Chains`. -The `model` passed to `predict` is often different from the one used to generate `chain`. -Typically, the model from which `chain` originated treats certain variables as observed (i.e., -data points), while the model you pass to `predict` may mark these same variables as missing -or unobserved. Calling `predict` then leverages the previously inferred parameter values to +The `model` passed to `predict` is often different from the one used to generate `chain`. +Typically, the model from which `chain` originated treats certain variables as observed (i.e., +data points), while the model you pass to `predict` may mark these same variables as missing +or unobserved. Calling `predict` then leverages the previously inferred parameter values to simulate what new, unobserved data might look like, given your posterior beliefs. 
For each parameter configuration in `chain`: @@ -59,7 +59,7 @@ For each parameter configuration in `chain`: 2. Any variables not included in `chain` are sampled from their prior distributions. If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by -the samples in `chain`. This is useful when you want to sample only new variables from the posterior +the samples in `chain`. This is useful when you want to sample only new variables from the posterior predictive distribution. # Examples @@ -115,7 +115,7 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - model(rng, varinfo, DynamicPPL.SampleFromPrior()) + varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( @@ -124,7 +124,7 @@ function DynamicPPL.predict( map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), ) - return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo)) + return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) end chain_result = reduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 21f9044cd..69e489ce6 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Bijectors using Compat using Distributions using OrderedCollections: OrderedCollections, OrderedDict +using Printf: Printf using AbstractMCMC: AbstractMCMC using ADTypes: ADTypes @@ -46,17 +47,28 @@ import Base: export AbstractVarInfo, VarInfo, SimpleVarInfo, + AbstractAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + NumProduceAccumulator, push!!, empty!!, subset, getlogp, + getlogjoint, + getlogprior, + getloglikelihood, setlogp!!, + setlogprior!!, + setloglikelihood!!, acclogp!!, + acclogprior!!, + 
accloglikelihood!!, resetlogp!!, get_num_produce, - set_num_produce!, - reset_num_produce!, - increment_num_produce!, + set_num_produce!!, + reset_num_produce!!, + increment_num_produce!!, set_retained_vns_del!, is_flagged, set_flag!, @@ -90,17 +102,13 @@ export AbstractVarInfo, # LogDensityFunction LogDensityFunction, # Contexts + contextualize, SamplingContext, DefaultContext, - LikelihoodContext, - PriorContext, - MiniBatchContext, PrefixContext, ConditionContext, assume, - observe, tilde_assume, - tilde_observe, # Pseudo distributions NamedDist, NoDist, @@ -120,7 +128,6 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, - @submodel, value_iterator_from_chain, check_model, check_model_and_trace, @@ -146,6 +153,9 @@ macro prob_str(str) )) end +# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo +# has to implement. Not sure what the full list is for parameters values, but for +# accumulators we only need `getaccs` and `setaccs!!`. """ AbstractVarInfo @@ -165,7 +175,10 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("submodel.jl") include("varnamedvector.jl") +include("accumulators.jl") +include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") @@ -173,7 +186,6 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("submodel_macro.jl") include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") @@ -214,6 +226,21 @@ if isdefined(Base.Experimental, :register_error_hint) ) end end + + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ + is_evaluate_three_arg = + exc.f === AbstractPPL.evaluate!! 
&& + length(argtypes) == 3 && + argtypes[1] <: Model && + argtypes[2] <: AbstractVarInfo && + argtypes[3] <: AbstractContext + if is_evaluate_three_arg + print( + io, + "\n\nThe method `evaluate!!(model, varinfo, new_ctx)` has been removed. Instead, you should store the `new_ctx` in the `model.context` field using `new_model = contextualize(model, new_ctx)`, and then call `evaluate!!(new_model, varinfo)` on the new model. (Note that, if the model already contained a non-default context, you will need to wrap the existing context.)", + ) + end + end end end diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 28bf488fa..68d3f9c03 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -90,45 +90,289 @@ Return the `AbstractTransformation` related to `vi`. function transformation end # Accumulation of log-probabilities. +""" + getlogjoint(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters in `vi`. + +See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). +""" +getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) + """ getlogp(vi::AbstractVarInfo) -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. +Return a NamedTuple of the log prior and log likelihood probabilities. + +The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an +error will be thrown. +""" +function getlogp(vi::AbstractVarInfo) + return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) +end + +""" + setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) + setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N}) + +Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. + +`setaccs!!(vi:AbstractVarInfo, accs::AccumulatorTuple) should be implemented by each subtype +of `AbstractVarInfo`. 
+""" +function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N} + return setaccs!!(vi, AccumulatorTuple(accs)) +end + +""" + getaccs(vi::AbstractVarInfo) + +Return the `AccumulatorTuple` of `vi`. + +This should be implemented by each subtype of `AbstractVarInfo`. +""" +function getaccs end + +""" + hasacc(vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Return a boolean for whether `vi` has an accumulator with name `accname`. +""" +hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) +function hasacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use hasacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + acckeys(vi::AbstractVarInfo) + +Return the names of the accumulators in `vi`. +""" +acckeys(vi::AbstractVarInfo) = keys(getaccs(vi)) + +""" + getlogprior(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters in `vi`. + +See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ref). +""" +getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp + +""" + getloglikelihood(vi::AbstractVarInfo) + +Return the log of the likelihood probability of the observed data in `vi`. + +See also: [`getlogjoint`](@ref), [`getlogprior`](@ref), [`setloglikelihood!!`](@ref). +""" +getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp + +""" + setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + +Add `acc` to the `AccumulatorTuple` of `vi`, mutating if it makes sense. + +If an accumulator with the same [`accumulator_name`](@ref) already exists, it will be +replaced. + +See also: [`getaccs`](@ref). 
+""" +function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + return setaccs!!(vi, setacc!!(getaccs(vi), acc)) +end + +""" + setlogprior!!(vi::AbstractVarInfo, logp) + +Set the log of the prior probability of the parameters sampled in `vi` to `logp`. + +See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref). +""" +setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) + +""" + setloglikelihood!!(vi::AbstractVarInfo, logp) + +Set the log of the likelihood probability of the observed data sampled in `vi` to `logp`. + +See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref). +""" +setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp)) + +""" + setlogp!!(vi::AbstractVarInfo, logp::NamedTuple) + +Set both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and `loglikelihood` and no other fields. + +See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). +""" +function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} + if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) + error("logp must have the fields logprior and loglikelihood and no other fields.") + end + vi = setlogprior!!(vi, logp.logprior) + vi = setloglikelihood!!(vi, logp.loglikelihood) + return vi +end + +function setlogp!!(vi::AbstractVarInfo, logp::Number) + return error(""" + `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use + `setloglikelihood!!` and/or `setlogprior!!` instead. + """) +end + +""" + getacc(vi::AbstractVarInfo, ::Val{accname}) + +Return the `AbstractAccumulator` of `vi` with name `accname`. 
+""" +function getacc(vi::AbstractVarInfo, accname::Val) + return getacc(getaccs(vi), accname) +end +function getacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method getacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use getacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + +Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. +""" +function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi) +end + +""" + accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + +Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. """ -function getlogp end +function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi) +end """ - setlogp!!(vi::AbstractVarInfo, logp) + map_accumulators!!(func::Function, vi::AbstractVarInfo) -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. +Update all accumulators of `vi` by calling `func` on them and replacing them with the return +values. """ -function setlogp!! end +function map_accumulators!!(func::Function, vi::AbstractVarInfo) + return setaccs!!(vi, map(func, getaccs(vi))) +end """ - acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname} -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. +Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the +return value. 
""" -function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(NodeTrait(context), context, vi, logp) +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val) + return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname)) +end + +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + does not exist. For type stability reasons use + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead. + """ + ) end -function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(vi, logp) + +""" + acclogprior!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the prior probability in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). +""" +function acclogprior!!(vi::AbstractVarInfo, logp) + return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) end -function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(childcontext(context), vi, logp) + +""" + accloglikelihood!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the likelihood in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). +""" +function accloglikelihood!!(vi::AbstractVarInfo, logp) + return map_accumulator!!( + acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) + ) +end + +""" + acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false) + +Add to both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields. 
+ +By default if the necessary accumulators are not in `vi` an error is thrown. If +`ignore_missing_accumulator` is set to `true` then this is silently ignored instead. +""" +function acclogp!!( + vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false +) where {names} + if !( + names == (:logprior, :loglikelihood) || + names == (:loglikelihood, :logprior) || + names == (:logprior,) || + names == (:loglikelihood,) + ) + error("logp must have fields logprior and/or loglikelihood and no other fields.") + end + if haskey(logp, :logprior) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior))) + vi = acclogprior!!(vi, logp.logprior) + end + if haskey(logp, :loglikelihood) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood))) + vi = accloglikelihood!!(vi, logp.loglikelihood) + end + return vi +end + +function acclogp!!(vi::AbstractVarInfo, logp::Number) + Base.depwarn( + "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.", + :acclogp, + ) + return accloglikelihood!!(vi, logp) end """ resetlogp!!(vi::AbstractVarInfo) -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. +Reset the values of the log probabilities (prior and likelihood) in `vi` to zero. """ -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) +function resetlogp!!(vi::AbstractVarInfo) + if hasacc(vi, Val(:LogPrior)) + vi = map_accumulator!!(zero, vi, Val(:LogPrior)) + end + if hasacc(vi, Val(:LogLikelihood)) + vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) + end + return vi +end # Variables and their realizations. 
@doc """ @@ -574,8 +818,8 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, y), lp_new) + lp_new = getlogprior(vi) - logjac + vi_new = setlogprior!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end @@ -586,8 +830,8 @@ function invlink!!( y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, x), lp_new) + lp_new = getlogprior(vi) + logjac + vi_new = setlogprior!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end @@ -731,9 +975,34 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) return x, logpdf(dist, x) + logjac end -# Legacy code that is currently overloaded for the sake of simplicity. -# TODO: Remove when possible. -increment_num_produce!(::AbstractVarInfo) = nothing +""" + get_num_produce(vi::AbstractVarInfo) + +Return the `num_produce` of `vi`. +""" +get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num + +""" + set_num_produce!!(vi::AbstractVarInfo, n::Int) + +Set the `num_produce` field of `vi` to `n`. +""" +set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n)) + +""" + increment_num_produce!!(vi::AbstractVarInfo) + +Add 1 to `num_produce` in `vi`. +""" +increment_num_produce!!(vi::AbstractVarInfo) = + map_accumulator!!(increment, vi, Val(:NumProduce)) + +""" + reset_num_produce!!(vi::AbstractVarInfo) + +Reset the value of `num_produce` in `vi` to 0. +""" +reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce)) """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl new file mode 100644 index 000000000..10a988ae5 --- /dev/null +++ b/src/accumulators.jl @@ -0,0 +1,189 @@ +""" + AbstractAccumulator + +An abstract type for accumulators. 
+ +An accumulator is an object that may change its value at every tilde_assume!! or +tilde_observe!! call based on the random variable in question. The obvious examples of +accumulators are the log prior and log likelihood. Other examples might be a variable that +counts the number of observations in a trace, or a list of the names of random variables +seen so far. + +An accumulator type `T <: AbstractAccumulator` must implement the following methods: +- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` +- `accumulate_observe!!(acc::T, right, left, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, right)` + +To be able to work with multi-threading, it should also implement: +- `split(acc::T)` +- `combine(acc::T, acc2::T)` + +See the documentation for each of these functions for more details. +""" +abstract type AbstractAccumulator end + +""" + accumulator_name(acc::AbstractAccumulator) + +Return a Symbol which can be used as a name for `acc`. + +The name has to be unique in the sense that a `VarInfo` can only have one accumulator for +each name. The most typical case, and the default implementation, is that the name only +depends on the type of `acc`, not on its value. +""" +accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) + +""" + accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) + +Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. + +`vn` is the name of the variable being observed, `left` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case +of literal observations like `0.0 ~ Normal()`. + +`accumulate_observe!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_assume!!`](@ref) +""" +function accumulate_observe!! end + +""" + accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right) + +Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. 
+ +`vn` is the name of the variable being assumed, `val` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `logjac` is the log +determinant of the Jacobian of the transformation that was done to convert the value of `vn` +as it was given (e.g. by a sampler operating in linked space) to `val`. + +`accumulate_assume!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_observe!!`](@ref) +""" +function accumulate_assume!! end + +""" + split(acc::AbstractAccumulator) + +Return a new accumulator like `acc` but empty. + +The precise meaning of "empty" is that the returned value should be such that +`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading +where different threads may accumulate independently and the results are then combined. + +See also: [`combine`](@ref) +""" +function split end + +""" + combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) + +Combine two accumulators of the same type. Returns a new accumulator. + +See also: [`split`](@ref) +""" +function combine end + +# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in +# src/varinfo.jl. +""" + convert_eltype(::Type{T}, acc::AbstractAccumulator) + +Convert `acc` to use element type `T`. + +What "element type" means depends on the type of `acc`. By default this function does +nothing. Accumulator types that need to hold differentiable values, such as dual numbers +used by various AD backends, should implement a method for this function. +""" +convert_eltype(::Type, acc::AbstractAccumulator) = acc + +""" + AccumulatorTuple{N,T<:NamedTuple} + +A collection of accumulators, stored as a `NamedTuple` of length `N`. + +This is defined as a separate type to be able to dispatch on it cleanly and without method +ambiguities or conflicts with other `NamedTuple` types.
We also use this type to enforce the +constraint that the name in the tuple for each accumulator `acc` must be +`accumulator_name(acc)`, and these names must be unique. + +The constructor can be called with a tuple or a `Vararg` of `AbstractAccumulator`s. The +names will be generated automatically. One can also call the constructor with a `NamedTuple`, +but the names in the argument will be discarded in favour of the generated ones. +""" +struct AccumulatorTuple{N,T<:NamedTuple} + nt::T + + function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}} + names = map(accumulator_name, t) + nt = NamedTuple{names}(t) + return new{N,typeof(nt)}(nt) + end +end + +AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) +AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) + +# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. +Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) +Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] +Base.length(::AccumulatorTuple{N}) where {N} = N +Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) +function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} + # @inline to ensure constant propagation can resolve this to a compile-time constant. + @inline return haskey(at.nt, accname) +end +Base.keys(at::AccumulatorTuple) = keys(at.nt) + +function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} + return AccumulatorTuple(convert(T, accs.nt)) +end + +""" + setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + +Add `acc` to `at`. Returns a new `AccumulatorTuple`. + +If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is +replaced. `at` will never be mutated, but the name has the `!!` for consistency with the +corresponding function for `AbstractVarInfo`.
+""" +function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + accname = accumulator_name(acc) + new_nt = merge(at.nt, NamedTuple{(accname,)}((acc,))) + return AccumulatorTuple(new_nt) +end + +""" + getacc(at::AccumulatorTuple, ::Val{accname}) + +Get the accumulator with name `accname` from `at`. +""" +function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} + return at[accname] +end + +function Base.map(func::Function, at::AccumulatorTuple) + return AccumulatorTuple(map(func, at.nt)) +end + +""" + map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname}) + +Update the accumulator with name `accname` in `at` by calling `func` on it. + +Returns a new `AccumulatorTuple`. +""" +function map_accumulator( + func::Function, at::AccumulatorTuple, ::Val{accname} +) where {accname} + # Would like to write this as + # return Accessors.@set at.nt[accname] = func(at[accname], args...) + # for readability, but that one isn't type stable due to + # https://github.com/JuliaObjects/Accessors.jl/issues/198 + new_val = func(at[accname]) + new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,))) + return AccumulatorTuple(new_nt) +end diff --git a/src/compiler.jl b/src/compiler.jl index 6f7489b8e..6384eaa7c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,4 +1,4 @@ -const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +const INTERNALNAMES = (:__model__, :__varinfo__) """ need_concretize(expr) @@ -29,6 +29,18 @@ function need_concretize(expr) end end +""" + make_varname_expression(expr) + +Return a `VarName` based on `expr`, concretizing it if necessary. +""" +function make_varname_expression(expr) + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. 
+ return AbstractPPL.drop_escape(varname(expr, need_concretize(expr))) +end + """ isassumption(expr[, vn]) @@ -48,15 +60,12 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption( - expr::Union{Expr,Symbol}, - vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), -) +function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote if $(DynamicPPL.contextual_isassumption)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) - # Considered an assumption by `__context__` which means either: + # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of # the model arguments, hence we need to check this. 
@@ -107,7 +116,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) return :($(DynamicPPL.contextual_isfixed)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) )) end @@ -167,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x)) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x -check_tilde_rhs(x::ReturnedModelWrapper) = x -function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} - model = check_tilde_rhs(x.model) - return Sampleable{typeof(model),AutoPrefix}(model) -end +check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x """ check_dot_tilde_rhs(x) @@ -402,14 +407,18 @@ function generate_mainbody!(mod, found, expr::Expr, warn) end function generate_assign(left, right) - right_expr = :($(TrackedValue)($right)) - tilde_expr = generate_tilde(left, right_expr) + # A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for + # ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator. 
+ @gensym acc right_val vn return quote - if $(is_extracting_values)(__context__) - $tilde_expr - else - $left = $right + $right_val = $right + if $(DynamicPPL.is_extracting_values)(__varinfo__) + $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) + __varinfo__ = $(map_accumulator!!)( + $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) + ) end + $left = $right_val end end @@ -418,7 +427,11 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __model__.context, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + nothing, + __varinfo__, ) $value end @@ -437,18 +450,13 @@ function generate_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption value dist - # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact - # that in DynamicPPL we the entire function body. Instead we should be - # more selective with our escape. Until that's the case, we remove them all. return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist - ) + $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) @@ -456,12 +464,12 @@ function generate_tilde(left, right) # If `vn` is not in `argnames`, we need to make sure that the variable is defined. 
if !$(DynamicPPL.inargnames)($vn, __model__) $left = $(DynamicPPL.getconditioned_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, + __model__.context, $(DynamicPPL.check_tilde_rhs)($dist), $(maybe_view(left)), $vn, @@ -486,7 +494,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( - __context__, + __model__.context, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) @@ -644,11 +652,7 @@ function build_output(modeldef, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( - [ - :(__model__::$(DynamicPPL.Model)), - :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__context__::$(DynamicPPL.AbstractContext)), - ], + [:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))], args, ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3ee88149e..b11a723a5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,24 +1,3 @@ -# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. 
-function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp) -end -function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(childcontext(context), vi, logp) -end -function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - -function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp) -end -function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(childcontext(context), vi, logp) -end -function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -36,44 +15,23 @@ function tilde_assume(context::SamplingContext, right, vn, vi) return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end -# Leaf contexts function tilde_assume(context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), context, args...) + return tilde_assume(childcontext(context), args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) - # no rng nor sampler +function tilde_assume(::DefaultContext, right, vn, vi) return assume(right, vn, vi) end -function tilde_assume(::IsParent, context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) + return tilde_assume(rng, childcontext(context), args...) 
end -function tilde_assume( - ::IsLeaf, rng::Random.AbstractRNG, context::AbstractContext, sampler, right, vn, vi -) - # rng and sampler +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(::IsLeaf, context::AbstractContext, sampler, right, vn, vi) - # sampler but no rng +function tilde_assume(::DefaultContext, sampler, right, vn, vi) + # same as above but no rng return assume(Random.default_rng(), sampler, right, vn, vi) end -function tilde_assume( - ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args... -) - # rng but no sampler - return tilde_assume(rng, childcontext(context), args...) -end - -function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(nodist(right), vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, nodist(right), vn, vi) -end function tilde_assume(context::PrefixContext, right, vn, vi) # Note that we can't use something like this here: @@ -105,78 +63,44 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) - return if is_rhs_model(right) - # Here, we apply the PrefixContext _not_ to the parent `context`, but - # to the context of the submodel being evaluated. This means that later= - # on in `make_evaluate_args_and_kwargs`, the context stack will be - # correctly arranged such that it goes like this: - # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> - # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext - # See the docstring of `make_evaluate_args_and_kwargs`, and the internal - # DynamicPPL documentation on submodel conditioning, for more details. - # - # NOTE: This relies on the existence of `right.model.model`. 
Right now, - # the only thing that can return true for `is_rhs_model` is something - # (a `Sampleable`) that has a `model` field that itself (a - # `ReturnedModelWrapper`) has a `model` field. This may or may not - # change in the future. - if should_auto_prefix(right) - dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext(vn, dppl_model.context) - new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) - right = to_submodel(new_dppl_model, true) - end - rand_like!!(right, context, vi) + return if right isa DynamicPPL.Submodel + _evaluate!!(right, vi, context, vn) else - value, logp, vi = tilde_assume(context, right, vn, vi) - value, acclogp_assume!!(context, vi, logp) + tilde_assume(context, right, vn, vi) end end # observe """ - tilde_observe(context::SamplingContext, right, left, vi) + tilde_observe!!(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +Falls back to `tilde_observe!!(context.context, right, left, vi)`. """ -function tilde_observe(context::SamplingContext, right, left, vi) - return tilde_observe(context.context, context.sampler, right, left, vi) -end - -# Leaf contexts -function tilde_observe(context::AbstractContext, args...) - return tilde_observe(NodeTrait(tilde_observe, context), context, args...) +function tilde_observe!!(context::SamplingContext, right, left, vn, vi) + return tilde_observe!!(context.context, right, left, vn, vi) end -tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) -function tilde_observe(::IsParent, context::AbstractContext, args...) - return tilde_observe(childcontext(context), args...) 
-end - -tilde_observe(::PriorContext, right, left, vi) = 0, vi -tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi -# `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - logp, vi = tilde_observe(context.context, sampler, right, left, vi) - return context.loglike_scalar * logp, vi +function tilde_observe!!(context::AbstractContext, right, left, vn, vi) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end -function tilde_observe(context::PrefixContext, sampler, right, left, vi) - return tilde_observe(context.context, sampler, right, left, vi) +function tilde_observe!!(context::PrefixContext, right, left, vn, vi) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. For the need for prefix_and_strip_contexts rather than just prefix, see the + # comment in `tilde_assume!!`. + new_vn, new_context = if vn !== nothing + prefix_and_strip_contexts(context, vn) + else + vn, childcontext(context) + end + return tilde_observe!!(new_context, right, left, new_vn, vi) end """ - tilde_observe!!(context, right, left, vname, vi) + tilde_observe!!(context, right, left, vn, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value and updated `vi`. @@ -184,46 +108,24 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. 
""" -function tilde_observe!!(context, right, left, vname, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return tilde_observe!!(context, right, left, vi) -end - -""" - tilde_observe(context, right, left, vi) - -Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and -return the observed value. - -By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) +function tilde_observe!!(::DefaultContext, right, left, vn, vi) + right isa DynamicPPL.Submodel && + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi end -function assume(rng::Random.AbstractRNG, spl::Sampler, dist) +function assume(::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end -function observe(spl::Sampler, weight) - return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") -end - # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r, logp = invlink_with_logpdf(vi, vn, dist) - return r, logp, vi + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, dist) + x, logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + return x, vi end # TODO: Remove this thing. 
@@ -245,8 +147,7 @@ function assume( r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! - # Also, if we use !! we shouldn't ignore the return value. - BangBang.setindex!!(vi, f(r), vn) + vi = BangBang.setindex!!(vi, f(r), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -256,22 +157,16 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist) + vi = push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. - settrans!!(vi, true, vn) + vi = settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) end end # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - return r, logpdf(dist, r) - logjac, vi -end - -# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) -observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) -function observe(right::Distribution, left, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(right, left), vi + vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + return r, vi end diff --git a/src/contexts.jl b/src/contexts.jl index 8ac085663..addadfa1a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -45,15 +45,17 @@ effectively updating the child context. 
# Examples ```jldoctest +julia> using DynamicPPL: DynamicTransformationContext + julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); julia> DynamicPPL.childcontext(ctx_prior) -PriorContext() +DynamicTransformationContext{true}() ``` """ setchildcontext @@ -78,7 +80,7 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext julia> struct ParentContext{C} <: AbstractContext context::C @@ -96,8 +98,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext() + leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) +DynamicTransformationContext{true}() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -129,7 +131,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. 
-See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R @@ -189,52 +191,11 @@ getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute the log joint probability of the data -and parameters when running the model. +The `DefaultContext` is used by default to accumulate values like the log joint probability +when running the model. """ struct DefaultContext <: AbstractContext end -NodeTrait(context::DefaultContext) = IsLeaf() - -""" - PriorContext <: AbstractContext - -A leaf context resulting in the exclusion of likelihood terms when running the model. -""" -struct PriorContext <: AbstractContext end -NodeTrait(context::PriorContext) = IsLeaf() - -""" - LikelihoodContext <: AbstractContext - -A leaf context resulting in the exclusion of prior terms when running the model. -""" -struct LikelihoodContext <: AbstractContext end -NodeTrait(context::LikelihoodContext) = IsLeaf() - -""" - struct MiniBatchContext{Tctx, T} <: AbstractContext - context::Tctx - loglike_scalar::T - end - -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. 
-""" -struct MiniBatchContext{Tctx,T} <: AbstractContext - context::Tctx - loglike_scalar::T -end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) - return MiniBatchContext(context, npoints / batch_size) -end -NodeTrait(context::MiniBatchContext) = IsParent() -childcontext(context::MiniBatchContext) = context.context -function setchildcontext(parent::MiniBatchContext, child) - return MiniBatchContext(child, parent.loglike_scalar) -end +NodeTrait(::DefaultContext) = IsLeaf() """ PrefixContext(vn::VarName[, context::AbstractContext]) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 754b344ee..4343ce8ac 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - logp varinfo = nothing end @@ -89,16 +88,12 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return print(io, stmt.value) end Base.@kwdef struct ObserveStmt <: Stmt left right - logp varinfo = nothing end @@ -107,10 +102,7 @@ function Base.show(io::IO, stmt::ObserveStmt) print(io, "observe: ") show_right(io, stmt.left) print(io, " ~ ") - show_right(io, stmt.right) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return show_right(io, stmt.right) end # Some utility methods for extracting information from a trace. @@ -139,9 +131,7 @@ A context used for checking validity of a model. 
# Fields $(FIELDS) """ -struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext - "model that is being run" - model::M +struct DebugContext{C<:AbstractContext} <: AbstractContext "context used for running the model" context::C "mapping from varnames to the number of times they have been seen" @@ -157,7 +147,6 @@ struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext end function DebugContext( - model::Model, context::AbstractContext=DefaultContext(); varnames_seen=OrderedDict{VarName,Int}(), statements=Vector{Stmt}(), @@ -166,7 +155,6 @@ function DebugContext( record_varinfo=false, ) return DebugContext( - model, context, varnames_seen, statements, @@ -252,12 +240,11 @@ function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) return nothing end -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, logp, varinfo) +function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) stmt = AssumeStmt(; varname=vn, right=dist, value=value, - logp=logp, varinfo=context.record_varinfo ? 
varinfo : nothing, ) if context.record_statements @@ -268,19 +255,17 @@ end function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi ) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume( - rng, childcontext(context), sampler, right, vn, vi - ) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end # observe @@ -304,12 +289,9 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) end end -function record_post_tilde_observe!(context::DebugContext, left, right, logp, varinfo) +function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) stmt = ObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? varinfo : nothing, + left=left, right=right, varinfo=context.record_varinfo ? 
varinfo : nothing ) if context.record_statements push!(context.statements, stmt) @@ -317,17 +299,17 @@ function record_post_tilde_observe!(context::DebugContext, left, right, logp, va return nothing end -function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end -function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -358,7 +340,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int}) end # A check we run on the model before evaluating it. -function check_model_pre_evaluation(context::DebugContext, model::Model) +function check_model_pre_evaluation(model::Model) issuccess = true # If something is in the model arguments, then it should NOT be in `condition`, # nor should there be any symbol present in `condition` that has the same symbol. 
@@ -375,8 +357,8 @@ function check_model_pre_evaluation(context::DebugContext, model::Model) return issuccess end -function check_model_post_evaluation(context::DebugContext, model::Model) - return check_varnames_seen(context.varnames_seen) +function check_model_post_evaluation(model::Model) + return check_varnames_seen(model.context.varnames_seen) end """ @@ -418,7 +400,7 @@ julia> issuccess true julia> print(trace) - assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356) + assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); ┌ Warning: The model does not contain any parameters. @@ -428,7 +410,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894) +observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model @@ -452,25 +434,23 @@ function check_model_and_trace( rng::Random.AbstractRNG, model::Model; varinfo=VarInfo(), - context=SamplingContext(rng), error_on_failure=false, kwargs..., ) # Execute the model with the debug context. debug_context = DebugContext( - model, context; error_on_failure=error_on_failure, kwargs... + SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... ) + debug_model = DynamicPPL.contextualize(model, debug_context) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_context, model) + issuccess = check_model_pre_evaluation(debug_model) # Force single-threaded execution. - retval, varinfo_result = DynamicPPL.evaluate_threadunsafe!!( - model, varinfo, debug_context - ) + DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo) # Perform checks after evaluating the model. 
- issuccess &= check_model_post_evaluation(debug_context, model) + issuccess &= check_model_post_evaluation(debug_model) if !issuccess && error_on_failure error("model check failed") @@ -549,14 +529,13 @@ function has_static_constraints( end """ - gen_evaluator_call_with_types(model[, varinfo, context]) + gen_evaluator_call_with_types(model[, varinfo]) Generate the evaluator call and the types of the arguments. # Arguments - `model::Model`: The model whose evaluator is of interest. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Returns A 2-tuple with the following elements: @@ -565,11 +544,9 @@ A 2-tuple with the following elements: - `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator. """ function gen_evaluator_call_with_types( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(), + model::Model, varinfo::AbstractVarInfo=VarInfo(model) ) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) return if isempty(kwargs) (model.f, Base.typesof(args...)) else @@ -578,7 +555,7 @@ end """ - model_warntype(model[, varinfo, context]; optimize=true) + model_warntype(model[, varinfo]; optimize=false) Check the type stability of the model's evaluator, warning about any potential issues. @@ -587,23 +564,19 @@ This simply calls `@code_warntype` on the model's evaluator, filling in internal # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
# Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `false`. """ function model_warntype( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=false, + model::Model, varinfo::AbstractVarInfo=VarInfo(model); optimize::Bool=false ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize) end """ - model_typed(model[, varinfo, context]; optimize=true) + model_typed(model[, varinfo]; optimize=true) Return the type inference for the model's evaluator. @@ -612,18 +585,14 @@ This simply calls `@code_typed` on the model's evaluator, filling in internal ar # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `true`. """ function model_typed( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=true, + model::Model, varinfo::AbstractVarInfo=VarInfo(model); optimize::Bool=true ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize)) end diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl new file mode 100644 index 000000000..ab538ba51 --- /dev/null +++ b/src/default_accumulators.jl @@ -0,0 +1,154 @@ +""" + LogPriorAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log prior during model execution.
+ +# Fields +$(TYPEDFIELDS) +""" +struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator + "the scalar log prior value" + logp::T +end + +""" + LogPriorAccumulator{T}() + +Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. +""" +LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) +LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() + +""" + LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log likelihood during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + "the scalar log likelihood value" + logp::T +end + +""" + LogLikelihoodAccumulator{T}() + +Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. +""" +LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) +LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() + +""" + NumProduceAccumulator{T} <: AbstractAccumulator + +An accumulator that tracks the number of observations during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator + "the number of observations" + num::T +end + +""" + NumProduceAccumulator{T<:Integer}() + +Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. 
+""" +NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T)) +NumProduceAccumulator() = NumProduceAccumulator{Int}() + +function Base.show(io::IO, acc::LogPriorAccumulator) + return print(io, "LogPriorAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::LogLikelihoodAccumulator) + return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::NumProduceAccumulator) + return print(io, "NumProduceAccumulator($(repr(acc.num)))") +end + +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood +accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce + +split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) +split(acc::NumProduceAccumulator) = acc + +function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc.logp + acc2.logp) +end +function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc.logp + acc2.logp) +end +function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator) + return NumProduceAccumulator(max(acc.num, acc2.num)) +end + +function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc1.logp + acc2.logp) +end +function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc1.logp + acc2.logp) +end +increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num)) + +Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) +Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num)) + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, 
vn, right) + return acc + LogPriorAccumulator(logpdf(right, val) + logjac) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) +end + +accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc +accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc) + +function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator +) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator +) where {T} + return NumProduceAccumulator(convert(T, acc.num)) +end + +# TODO(mhauru) +# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on +# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to +# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is +# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. 
+function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end + +function default_accumulators( + ::Type{FloatT}=LogProbType, ::Type{IntT}=Int +) where {FloatT,IntT} + return AccumulatorTuple( + LogPriorAccumulator{FloatT}(), + LogLikelihoodAccumulator{FloatT}(), + NumProduceAccumulator{IntT}(), + ) +end diff --git a/src/experimental.jl b/src/experimental.jl index 84038803c..974912957 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -4,16 +4,15 @@ using DynamicPPL: DynamicPPL # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ - is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) + is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) -Check if the `model` supports evaluation using the provided `context` and `varinfo`. +Check if the `model` supports evaluation using the provided `varinfo`. !!! warning Loading JET.jl is required before calling this function. # Arguments - `model`: The model to verify the support for. -- `context`: The context to use for the model evaluation. - `varinfo`: The varinfo to verify the support for. # Keyword Arguments @@ -29,7 +28,7 @@ function is_suitable_varinfo end function _determine_varinfo_jet end """ - determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) + determine_suitable_varinfo(model; only_ddpl::Bool=true) Return a suitable varinfo for the given `model`. @@ -41,7 +40,6 @@ See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). # Arguments - `model`: The model for which to determine the varinfo. -- `context`: The context to use for the model evaluation. 
Default: `SamplingContext()`. # Keyword Arguments - `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. @@ -85,14 +83,10 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) true ``` """ -function determine_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); - only_ddpl::Bool=true, -) +function determine_suitable_varinfo(model::DynamicPPL.Model; only_ddpl::Bool=true) # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model, context; only_ddpl) + _determine_varinfo_jet(model; only_ddpl) else # Warn the user. @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 0f312fa2c..bd6bdb2f2 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -1,44 +1,47 @@ -struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <: - AbstractContext +struct PriorDistributionAccumulator{D<:OrderedDict{VarName,Any}} <: AbstractAccumulator priors::D - context::Ctx end -PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context) +PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}()) -NodeTrait(::PriorExtractorContext) = IsParent() -childcontext(context::PriorExtractorContext) = context.context -function setchildcontext(parent::PriorExtractorContext, child) - return PriorExtractorContext(parent.priors, child) +accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator + +split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) +function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator) + return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors)) end -function 
setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution) - return context.priors[vn] = dist +function setprior!(acc::PriorDistributionAccumulator, vn::VarName, dist::Distribution) + acc.priors[vn] = dist + return acc end function setprior!( - context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dist::Distribution ) for vn in vns - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end function setprior!( - context::PriorExtractorContext, + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution}, ) for (vn, dist) in zip(vns, dists) - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end -function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) +function accumulate_assume!!(acc::PriorDistributionAccumulator, val, logjac, vn, right) + return setprior!(acc, vn, right) end +accumulate_observe!!(acc::PriorDistributionAccumulator, right, left, vn) = acc + """ extract_priors([rng::Random.AbstractRNG, ]model::Model) @@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - context = PriorExtractorContext(SamplingContext(rng)) - evaluate!!(model, VarInfo(), context) - return context.priors + varinfo = VarInfo() + # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a + # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you + # can't push new variables without knowing the num_produce. Remove this when possible. 
+ varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) + varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end """ @@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo` and recording the distributions that are present at each tilde statement. """ function extract_priors(model::Model, varinfo::AbstractVarInfo) - context = PriorExtractorContext(DefaultContext()) - evaluate!!(model, deepcopy(varinfo), context) - return context.priors + # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a + # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you + # can't push new variables without knowing the num_produce. Remove this when possible. + varinfo = setaccs!!( + deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator()) + ) + varinfo = last(evaluate!!(model, varinfo)) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 06c188ed6..e7565d137 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,8 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -29,10 +28,9 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with its -the type of varinfo to be used, as well as the evaluation context. These must -be known in order to calculate the log density (using -[`DynamicPPL.evaluate!!`](@ref)). 
+At its most basic level, a LogDensityFunction wraps the model together with the +type of varinfo to be used. These must be known in order to calculate the log +density (using [`DynamicPPL.evaluate!!`](@ref)). If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -52,7 +50,7 @@ $(FIELDS) ```jldoctest julia> using Distributions -julia> using DynamicPPL: LogDensityFunction, contextualize +julia> using DynamicPPL: LogDensityFunction, setaccs!! julia> @model function demo(x) m ~ Normal() @@ -79,8 +77,8 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary. julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # This also respects the context in `model`. - f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); +julia> # LogDensityFunction respects the accumulators in VarInfo: + f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -95,14 +93,12 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} } <: AbstractModel "model used for evaluation" model::M "varinfo used for evaluation" varinfo::V - "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" - context::C "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" adtype::AD "(internal use only) gradient preparation object for the model" @@ -110,35 +106,29 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=leafcontext(model.context); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing prep = nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo, context) + adtype = tweak_adtype(adtype, model, varinfo) # Check whether it is supported is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x) + prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x) else prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(varinfo), - DI.Constant(context), + logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo) ) end end - return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + return new{typeof(model),typeof(varinfo),typeof(adtype)}( + model, varinfo, adtype, prep ) end end @@ -149,9 +139,9 @@ end adtype::Union{Nothing,ADTypes.AbstractADType} ) -Create a new LogDensityFunction using the model, varinfo, and context from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass -`nothing` as the second argument. +Create a new LogDensityFunction using the model and varinfo from the given +`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, +pass `nothing` as the second argument. 
""" function LogDensityFunction( f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} @@ -159,7 +149,7 @@ function LogDensityFunction( return if adtype === f.adtype f # Avoid recomputing prep if not needed else - LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) + LogDensityFunction(f.model, f.varinfo; adtype=adtype) end end @@ -168,77 +158,74 @@ end x::AbstractVector, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Evaluate the log density of the given `model` at the given parameter values `x`, -using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted into -it, and its own parameters are discarded. +using the given `varinfo`. Note that the `varinfo` argument is provided only +for its structure, in the sense that the parameters from the vector `x` are +inserted into it, and its own parameters are discarded. It does, however, +determine whether the log prior, likelihood, or joint is returned, based on +which accumulators are set in it. 
""" -function logdensity_at( - x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) +function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo) varinfo_new = unflatten(varinfo, x) - return getlogp(last(evaluate!!(model, varinfo_new, context))) + varinfo_eval = last(evaluate!!(model, varinfo_new)) + has_prior = hasacc(varinfo_eval, Val(:LogPrior)) + has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) + if has_prior && has_likelihood + return getlogjoint(varinfo_eval) + elseif has_prior + return getlogprior(varinfo_eval) + elseif has_likelihood + return getloglikelihood(varinfo_eval) + else + error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") + end end """ - LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}( + LogDensityAt{M<:Model,V<:AbstractVarInfo}( model::M varinfo::V - context::C ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -varinfo, context)`. +varinfo)`. 
""" -struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} +struct LogDensityAt{M<:Model,V<:AbstractVarInfo} model::M varinfo::V - context::C -end -function (ld::LogDensityAt)(x::AbstractVector) - varinfo_new = unflatten(ld.varinfo, x) - return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context))) end +(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo) ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,Nothing}} -) where {M,V,C} + ::Type{<:LogDensityFunction{M,V,Nothing}} +) where {M,V} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,AD}} -) where {M,V,C,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,V,AD}} +) where {M,V,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo, f.context) + return logdensity_at(x, f.model, f.varinfo) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,C,AD}, x::AbstractVector -) where {M,V,C,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,V,AD}, x::AbstractVector +) where {M,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type # Make branching statically inferrable, i.e. 
type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x - ) + DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x) else DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.varinfo), - DI.Constant(f.context), + logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo) ) end end @@ -253,7 +240,6 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) adtype::ADTypes.AbstractADType, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Return an 'optimised' form of the adtype. This is useful for doing @@ -264,9 +250,7 @@ model. By default, this just returns the input unchanged. """ -tweak_adtype( - adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext -) = adtype +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype """ use_closure(adtype::ADTypes.AbstractADType) @@ -286,17 +270,14 @@ There are two ways of dealing with this: The relative performance of the two approaches, however, depends on the AD backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480 +https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 This function is used to determine whether a given AD backend should use a closure or a constant. If `use_closure(adtype)` returns `true`, then the closure approach will be used. -By default, this function returns `false`, i.e. the constant approach will be used. +By default, this function returns `true`, i.e. the closure approach will be used.
""" -use_closure(::ADTypes.AbstractADType) = false -use_closure(::ADTypes.AutoForwardDiff) = false -use_closure(::ADTypes.AutoMooncake) = false -use_closure(::ADTypes.AutoReverseDiff) = true +use_closure(::ADTypes.AbstractADType) = true """ getmodel(f) @@ -311,7 +292,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. """ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) + return LogDensityFunction(model, f.varinfo; adtype=f.adtype) end """ diff --git a/src/model.jl b/src/model.jl index c7c4bdf57..93e77eaec 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,6 +85,12 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k return Model(f, args, NamedTuple(kwargs), context) end +""" + contextualize(model::Model, context::AbstractContext) + +Return a new `Model` with the same evaluation function and other arguments, but +with its underlying context set to `context`. +""" function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end @@ -252,7 +258,7 @@ julia> # However, it's not possible to condition `inner` directly. conditioned_model_fail = model | (inner = 1.0, ); julia> conditioned_model_fail() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed [...] ``` """ @@ -794,15 +800,23 @@ julia> # Now `a.x` will be sampled. fixed(model::Model) = fixed(model.context) """ - (model::Model)([rng, varinfo, sampler, context]) + (model::Model)([rng, varinfo]) -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. 
+Sample from the prior of the `model` with random number generator `rng`. -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns the model's return value. + +Note that calling this with an existing `varinfo` object will mutate it. """ -(model::Model)(args...) = first(evaluate!!(model, args...)) +(model::Model)() = model(Random.default_rng(), VarInfo()) +function (model::Model)(varinfo::AbstractVarInfo) + return model(Random.default_rng(), varinfo) +end +# ^ Weird Documenter.jl bug means that we have to write the two above separately +# as it can only detect the `function`-less syntax. +function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) + return first(evaluate_and_sample!!(rng, model, varinfo)) +end """ use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) @@ -815,65 +829,52 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - evaluate!!(model::Model[, rng, varinfo, sampler, context]) + evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. +Evaluate the `model` with the given `varinfo`, but perform sampling during the +evaluation using the given `sampler` by wrapping the model's context in a +`SamplingContext`. -Returns both the return-value of the original model, and the resulting varinfo. +If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns a tuple of the model's return value, plus the updated `varinfo` object. 
""" -function AbstractPPL.evaluate!!( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) - return if use_threadsafe_eval(context, varinfo) - evaluate_threadsafe!!(model, varinfo, context) - else - evaluate_threadunsafe!!(model, varinfo, context) - end -end - -function AbstractPPL.evaluate!!( - model::Model, +function evaluate_and_sample!!( rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), + model::Model, + varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), ) - return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)) + sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return evaluate!!(sampling_model, varinfo) end - -function AbstractPPL.evaluate!!(model::Model, context::AbstractContext) - return evaluate!!(model, VarInfo(), context) -end - -function AbstractPPL.evaluate!!( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... +function evaluate_and_sample!!( + model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() ) - return evaluate!!(model, Random.default_rng(), args...) + return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end -# without VarInfo -function AbstractPPL.evaluate!!( - model::Model, - rng::Random.AbstractRNG, - sampler::AbstractSampler, - args::AbstractContext..., -) - return evaluate!!(model, rng, VarInfo(), sampler, args...) -end +""" + evaluate!!(model::Model, varinfo) -# without VarInfo and without AbstractSampler -function AbstractPPL.evaluate!!( - model::Model, rng::Random.AbstractRNG, context::AbstractContext -) - return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context) +Evaluate the `model` with the given `varinfo`. + +If multiple threads are available, the varinfo provided will be wrapped in a +`ThreadSafeVarInfo` before evaluation. 
+
+Returns a tuple of the model's return value, plus the updated `varinfo`
+(unwrapped if necessary).
+"""
+function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
+    return if use_threadsafe_eval(model.context, varinfo)
+        evaluate_threadsafe!!(model, varinfo)
+    else
+        evaluate_threadunsafe!!(model, varinfo)
+    end
 end
 
 """
-    evaluate_threadunsafe!!(model, varinfo, context)
+    evaluate_threadunsafe!!(model, varinfo)
 
 Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
 
@@ -882,8 +883,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
 
 See also: [`evaluate_threadsafe!!`](@ref)
 """
-function evaluate_threadunsafe!!(model, varinfo, context)
-    return _evaluate!!(model, resetlogp!!(varinfo), context)
+function evaluate_threadunsafe!!(model, varinfo)
+    return _evaluate!!(model, resetlogp!!(varinfo))
 end
 
 """
@@ -897,31 +898,38 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
 
 See also: [`evaluate_threadunsafe!!`](@ref)
 """
-function evaluate_threadsafe!!(model, varinfo, context)
+function evaluate_threadsafe!!(model, varinfo)
     wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo))
-    result, wrapper_new = _evaluate!!(model, wrapper, context)
-    return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new))
+    result, wrapper_new = _evaluate!!(model, wrapper)
+    # TODO(penelopeysm): It seems that if you pass a TSVI to this method, it
+    # will return the underlying VI, which is a bit counterintuitive (because
+    # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
+    # again).
+    return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
 end
 
 """
-    _evaluate!!(model::Model, varinfo, context)
+    _evaluate!!(model::Model, varinfo)
 
-Evaluate the `model` with the arguments matching the given `context` and `varinfo` object.
+Evaluate the `model` with the given `varinfo`.
+
+This function does not wrap the varinfo in a `ThreadSafeVarInfo`.
It also does not
+reset the log probability of the `varinfo` before running.
 """
-function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
-    args, kwargs = make_evaluate_args_and_kwargs(model, varinfo, context)
+function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
+    args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
     return model.f(args...; kwargs...)
 end
 
 is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
 
 """
-    make_evaluate_args_and_kwargs(model, varinfo, context)
+    make_evaluate_args_and_kwargs(model, varinfo)
 
 Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`.
 """
 @generated function make_evaluate_args_and_kwargs(
-    model::Model{_F,argnames}, varinfo::AbstractVarInfo, context::AbstractContext
+    model::Model{_F,argnames}, varinfo::AbstractVarInfo
 ) where {_F,argnames}
     unwrap_args = [
         if is_splat_symbol(var)
@@ -930,18 +938,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
             :($matchingvalue(varinfo, model.args.$var))
         end for var in argnames
     ]
-
-    # We want to give `context` precedence over `model.context` while also
-    # preserving the leaf context of `context`. We can do this by
-    # 1. Set the leaf context of `model.context` to `leafcontext(context)`.
-    # 2. Set leaf context of `context` to the context resulting from (1).
-    # The result is:
-    # `context` -> `childcontext(context)` -> ... -> `model.context`
-    # -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
     return quote
-        context_new = setleafcontext(
-            context, setleafcontext(model.context, leafcontext(context))
-        )
         args = (
             model,
             # Maybe perform `invlink!!` once prior to evaluation to avoid
            # speeding up computation. See docs for `maybe_invlink_before_eval!!`
            # for more information.
maybe_invlink_before_eval!!(varinfo, model), - context_new, $(unwrap_args...), ) kwargs = model.defaults @@ -985,15 +981,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate!!( - model, - SimpleVarInfo{Float64}(OrderedDict()), - # NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`, - # and b) avoid double-stacking the parent contexts. - SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)), - ), - ) + x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) return values_as(x, T) end @@ -1010,7 +998,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo))) end """ @@ -1057,7 +1045,14 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + logprioracc = if hasacc(varinfo, Val(:LogPrior)) + getacc(varinfo, Val(:LogPrior)) + else + LogPriorAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,)) + return getlogprior(last(evaluate!!(model, varinfo))) end """ @@ -1104,7 +1099,14 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). 
""" function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + loglikelihoodacc = if hasacc(varinfo, Val(:LogLikelihood)) + getacc(varinfo, Val(:LogLikelihood)) + else + LogLikelihoodAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,)) + return getloglikelihood(last(evaluate!!(model, varinfo))) end """ @@ -1144,7 +1146,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end """ - predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) + predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) Generate samples from the posterior predictive distribution by evaluating `model` at each set of parameter values provided in `chain`. The number of posterior predictive samples matches @@ -1158,7 +1160,7 @@ function predict( return map(chain) do params_varinfo vi = deepcopy(varinfo) DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi, SampleFromPrior()) + model(rng, vi) return vi end end @@ -1206,243 +1208,3 @@ end function returned(model::Model, values, keys) return returned(model, NamedTuple{keys}(values)) end - -""" - is_rhs_model(x) - -Return `true` if `x` is a model or model wrapper, and `false` otherwise. -""" -is_rhs_model(x) = false - -""" - Distributional - -Abstract type for type indicating that something is "distributional". -""" -abstract type Distributional end - -""" - should_auto_prefix(distributional) - -Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. -""" -function should_auto_prefix end - -""" - is_rhs_model(x) - -Return `true` if the `distributional` is a model, and `false` otherwise. 
-""" -function is_rhs_model end - -""" - Sampleable{M} <: Distributional - -A wrapper around a model indicating it is sampleable. -""" -struct Sampleable{M,AutoPrefix} <: Distributional - model::M -end - -should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix -is_rhs_model(x::Sampleable) = is_rhs_model(x.model) - -# TODO: Export this if it end up having a purpose beyond `to_submodel`. -""" - to_sampleable(model[, auto_prefix]) - -Return a wrapper around `model` indicating it is sampleable. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. -""" -to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) - -""" - rand_like!!(model_wrap, context, varinfo) - -Returns a tuple with the first element being the realization and the second the updated varinfo. - -# Arguments -- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. -- `context::AbstractContext`: the context to use for evaluation. -- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. - """ -function rand_like!!( - model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo -) - return rand_like!!(model_wrap.model, context, varinfo) -end - -""" - ReturnedModelWrapper - -A wrapper around a model indicating it is a model over its return values. - -This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. -""" -struct ReturnedModelWrapper{M<:Model} - model::M -end - -is_rhs_model(::ReturnedModelWrapper) = true - -function rand_like!!( - model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo -) - # Return's the value and the (possibly mutated) varinfo. - return _evaluate!!(model_wrap.model, varinfo, context) -end - -""" - returned(model) - -Return a `model` wrapper indicating that it is a model over its return-values. 
-""" -returned(model::Model) = ReturnedModelWrapper(model) - -""" - to_submodel(model::Model[, auto_prefix::Bool]) - -Return a model wrapper indicating that it is a sampleable model over the return-values. - -This is mainly meant to be used on the right-hand side of a `~` operator to indicate that -the model can be sampled from but not necessarily evaluated for its log density. - -!!! warning - Note that some other operations that one typically associate with expressions of the form - `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. - -!!! warning - To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. - If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand - side of the `~` statement. Default: `true`. - -# Examples - -## Simple example -```jldoctest submodel-to_submodel; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - a ~ to_submodel(demo1(x)) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel-to_submodel -julia> vi = VarInfo(demo2(missing, 0.4)); - -julia> @varname(a.x) in keys(vi) -true -``` - -The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, -and can be used in subsequent lines of the model, as shown above. 
-```jldoctest submodel-to_submodel -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel -julia> x = vi[@varname(a.x)]; - -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` - -## Without automatic prefixing -As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically -prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel -will not be prefixed. -```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2_no_prefix(x, z) - a ~ to_submodel(demo1(x), false) - return z ~ Uniform(-a, 1) - end; - -julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); - -julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` -true -``` -However, not using prefixing is generally not recommended as it can lead to variable name clashes -unless one is careful. 
For example, if we're re-using the same model twice in a model, not using prefixing -will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): -```jldoctest submodel-to_submodel-prefix -julia> @model function demo2(x, y, z) - a ~ to_submodel(prefix(demo1(x), :sub1), false) - b ~ to_submodel(prefix(demo1(y), :sub2), false) - return z ~ Uniform(-a, b) - end; - -julia> vi = VarInfo(demo2(missing, missing, 0.4)); - -julia> @varname(sub1.x) in keys(vi) -true - -julia> @varname(sub2.x) in keys(vi) -true -``` - -Variables `a` and `b` are not tracked, but are assigned the return values of the respective -calls to `demo1`: -```jldoctest submodel-to_submodel-prefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogp(vi) ≈ logprior + loglikelihood -true -``` - -## Usage as likelihood is illegal - -Note that it is illegal to use a `to_submodel` model as a likelihood in another model: - -```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> @model illegal_likelihood() = a ~ to_submodel(inner()) -illegal_likelihood (generic function with 2 methods) - -julia> model = illegal_likelihood() | (a = 1.0,); - -julia> model() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported -[...] 
-``` -""" -to_submodel(model::Model, auto_prefix::Bool=true) = - to_sampleable(returned(model), auto_prefix) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..59cc5e1bb 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,142 +1,117 @@ -# Context version -struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext - logdensities::A - context::Ctx -end +""" + PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator -function PointwiseLogdensityContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DefaultContext(), -) - return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end +An accumulator that stores the log-probabilities of each variable in a model. -NodeTrait(::PointwiseLogdensityContext) = IsParent() -childcontext(context::PointwiseLogdensityContext) = context.context -function setchildcontext(context::PointwiseLogdensityContext, child) - return PointwiseLogdensityContext(context.logdensities, child) -end +Internally this accumulator stores the log-probabilities in a dictionary, where +the keys are the variable names and the values are vectors of +log-probabilities. Each element in a vector corresponds to one execution of the +model. -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies +which log-probabilities to store in the accumulator. `KeyType` is the type by which variable +names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary +used internally to store the log-probabilities, by default +`OrderedDict{KeyType, Vector{LogProbType}}`. 
+""" +struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: + AbstractAccumulator + logps::D end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[vn] = logp +function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) +function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,VarName}() end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[string(vn)] = logp +function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} + logps = OrderedDict{KeyType,Vector{LogProbType}}() + return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) + logps = acc.logps + # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
+ T = last(fieldtypes(eltype(logps))) + logpvec = get!(logps, vn, T()) + return push!(logpvec, logp) end function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logdensities[vn] = logp + acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp +) where {whichlogprob} + return push!(acc, string(vn), logp) end -function _include_prior(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{PriorContext,DefaultContext} -end -function _include_likelihood(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +function accumulator_name( + ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} +) where {whichlogprob} + return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) +function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return tilde_observe!!(context.context, right, left, vn, vi) - end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - # Track loglikelihood value. 
- push!(context, vn, logp) - - return left, acclogp!!(vi, logp) +function combine( + acc::PointwiseLogProbAccumulator{whichlogprob}, + acc2::PointwiseLogProbAccumulator{whichlogprob}, +) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) end -# Note on submodels (penelopeysm) -# -# We don't need to overload tilde_observe!! for Sampleables (yet), because it -# is currently not possible to evaluate a model with a Sampleable on the RHS -# of an observe statement. -# -# Note that calling tilde_assume!! on a Sampleable does not necessarily imply -# that there are no observe statements inside the Sampleable. There could well -# be likelihood terms in there, which must be included in the returned logp. -# See e.g. the `demo_dot_assume_observe_submodel` demo model. -# -# This is handled by passing the same context to rand_like!!, which figures out -# which terms to include using the context, and also mutates the context and vi -# appropriately. Thus, we don't need to check against _include_prior(context) -# here. -function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi) - value, vi = DynamicPPL.rand_like!!(right, context, vi) - return value, vi +function accumulate_assume!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right +) where {whichlogprob} + if whichlogprob == :both || whichlogprob == :prior + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) + push!(acc, vn, subacc.logp) + end + return acc end -function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) - !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) - value, logp, vi = tilde_assume(context.context, right, vn, vi) - # Track loglikelihood value. 
- push!(context, vn, logp)
-    return value, acclogp!!(vi, logp)
+function accumulate_observe!!(
+    acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn
+) where {whichlogprob}
+    # If `vn` is nothing the LHS of ~ is a literal and we don't have a name to attach this
+    # acc to, and thus do nothing.
+    if vn === nothing
+        return acc
+    end
+    if whichlogprob == :both || whichlogprob == :likelihood
+        # T is the element type of the vectors that are the values of `acc.logps`. Usually
+        # it's LogProbType.
+        T = eltype(last(fieldtypes(eltype(acc.logps))))
+        subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn)
+        push!(acc, vn, subacc.logp)
+    end
+    return acc
 end
 
 """
-    pointwise_logdensities(model::Model, chain::Chains, keytype = String)
+    pointwise_logdensities(
+        model::Model,
+        chain::Chains,
+        keytype=String,
+        ::Val{whichlogprob}=Val(:both),
+    )
 
 Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
 with keys corresponding to symbols of the variables, and values being matrices
 of shape `(num_chains, num_samples)`.
 
 `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
-Currently, only `String` and `VarName` are supported.
+Currently, only `String` and `VarName` are supported. `whichlogprob` specifies
+which log-probabilities to compute. It can be `:both`, `:prior`, or
+`:likelihood`.
+
+See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref).
 
 # Notes
Say `y` is a `Vector` of `n` i.i.d.
`Normal(μ, σ)` variables, with `μ` and `σ` @@ -234,14 +209,15 @@ julia> m = demo([1.0; 1.0]); julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` - """ function pointwise_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} + model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both) +) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) - point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} + vi = setaccs!!(vi, (AccType(),)) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -249,83 +225,59 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, point_context) + vi = last(evaluate!!(model, vi)) end + logps = getacc(vi, Val(accumulator_name(AccType))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in point_context.logdensities + varname => reshape(vals, niters, nchains) for (varname, vals) in logps ) return logdensities end function pointwise_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context - ) - model(varinfo, point_context) - return point_context.logdensities + model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) +) where {whichlogprob} + AccType = PointwiseLogProbAccumulator{whichlogprob} + varinfo = setaccs!!(varinfo, (AccType(),)) + varinfo = last(evaluate!!(model, varinfo)) + return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ - pointwise_loglikelihoods(model, chain[, 
keytype, context]) + pointwise_loglikelihoods(model, chain[, keytype]) Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the likelihood terms. -See also: [`pointwise_logdensities`](@ref). -""" -function pointwise_loglikelihoods( - model::Model, - chain, - keytype::Type{T}=String, - context::AbstractContext=LikelihoodContext(), -) where {T} - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - return pointwise_logdensities(model, chain, T, context) +See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). +""" +function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} + return pointwise_logdensities(model, chain, T, Val(:likelihood)) end -function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() -) - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, varinfo, context) +function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:likelihood)) end """ - pointwise_prior_logdensities(model, chain[, keytype, context]) + pointwise_prior_logdensities(model, chain[, keytype]) Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the prior terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
""" function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() + model::Model, chain, keytype::Type{T}=String ) where {T} - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, Val(:prior)) end -function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() -) - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, varinfo, context) +function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/src/sampler.jl b/src/sampler.jl index 49d910fec..673b5128f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,12 +58,12 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - model(rng, vi, sampler) + DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) return vi, nothing end """ - default_varinfo(rng, model, sampler[, context]) + default_varinfo(rng, model, sampler) Return a default varinfo object for the given `model` and `sampler`. @@ -71,22 +71,13 @@ Return a default varinfo object for the given `model` and `sampler`. - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. - `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. -- `context::AbstractContext`: Context in which the model is evaluated. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. 
""" function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - return default_varinfo(rng, model, sampler, DefaultContext()) -end -function default_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, -) init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler) end function AbstractMCMC.sample( @@ -119,7 +110,7 @@ function AbstractMCMC.step( # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi, DefaultContext())) + vi = last(evaluate!!(model, vi)) end return initialstep(rng, model, spl, vi; initial_params, kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2297bc9e1..ddc3275ae 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -36,13 +36,10 @@ julia> m = demo(); julia> rng = StableRNG(42); -julia> ### Sampling ### - ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); - julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -60,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. 
- _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -94,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! 
true

@@ -125,18 +122,18 @@ Evaluation in transformed space of course also works:

```jldoctest simplevarinfo-general
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
-Transformed SimpleVarInfo((x = -1.0,), 0.0)
+Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))

julia> # (✓) Positive probability mass on negative numbers!
-       getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
+       getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
-1.3678794411714423

julia> # While if we forget to indicate that it's transformed:
       vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
-SimpleVarInfo((x = -1.0,), 0.0)
+SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))

julia> # (✓) No probability mass on negative numbers!
-       getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
+       getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
-Inf
```

@@ -188,41 +185,37 @@
ERROR: type NamedTuple has no field b
[...]
```
"""
-struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo
+struct SimpleVarInfo{NT,Accs<:AccumulatorTuple,C<:AbstractTransformation} <:
+       AbstractVarInfo
    "underlying representation of the realization represented"
    values::NT
-    "holds the accumulated log-probability"
-    logp::T
+    "tuple of accumulators for things like log prior and log likelihood"
+    accs::Accs
    "represents whether it assumes variables to be transformed"
    transformation::C
end

transformation(vi::SimpleVarInfo) = vi.transformation

-# Makes things a bit more readable vs. putting `Float64` everywhere.
-const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64 - -function SimpleVarInfo{NT,T}(values, logp) where {NT,T} - return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation()) +function SimpleVarInfo(values, accs) + return SimpleVarInfo(values, accs, NoTransformation()) end -function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T)) +function SimpleVarInfo{T}(values) where {T<:Real} + return SimpleVarInfo(values, default_accumulators(T)) end - -# Constructors without type-specification. -SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) -function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict}) - return if isempty(θ) +function SimpleVarInfo(values) + return SimpleVarInfo{LogProbType}(values) +end +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict}) + return if isempty(values) # Can't infer from values, so we just use default. - SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) + SimpleVarInfo{LogProbType}(values) else # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ) + SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) end end -SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp) - # Using `kwargs` to specify the values. function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -232,45 +225,59 @@ function SimpleVarInfo(; kwargs...) end # Constructor from `Model`. -function SimpleVarInfo( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... -) - return SimpleVarInfo{Float64}(model, args...) 
+function SimpleVarInfo{T}( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) where {T<:Real} + new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return last(evaluate!!(new_model, SimpleVarInfo{T}())) end function SimpleVarInfo{T}( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... + model::Model, sampler::AbstractSampler=SampleFromPrior() ) where {T<:Real} - return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) + return SimpleVarInfo{T}(Random.default_rng(), model, sampler) +end +# Constructors without type param +function SimpleVarInfo( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) + return SimpleVarInfo{LogProbType}(rng, model, sampler) +end +function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} - return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) 
+function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D}
+    values = values_as(vi, D)
+    return SimpleVarInfo(values, deepcopy(getaccs(vi)))
end
-function SimpleVarInfo{T}(
-    vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
-) where {T<:Real,names,D}
+function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
    values = values_as(vi, D)
-    return SimpleVarInfo(values, convert(T, getlogp(vi)))
+    accs = map(acc -> convert_eltype(T, acc), getaccs(vi))
+    return SimpleVarInfo(values, accs)
end

function untyped_simple_varinfo(model::Model)
    varinfo = SimpleVarInfo(OrderedDict())
-    return last(evaluate!!(model, varinfo, SamplingContext()))
+    return last(evaluate_and_sample!!(model, varinfo))
end

function typed_simple_varinfo(model::Model)
    varinfo = SimpleVarInfo{Float64}()
-    return last(evaluate!!(model, varinfo, SamplingContext()))
+    return last(evaluate_and_sample!!(model, varinfo))
end

function unflatten(svi::SimpleVarInfo, x::AbstractVector)
-    logp = getlogp(svi)
    vals = unflatten(svi.values, x)
-    T = eltype(x)
-    return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}(
-        vals, T(logp), svi.transformation
+    # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
+    # required but undesirable.
+    # The below line is finicky for type stability. For instance, assigning the eltype
+    # to convert to into an intermediate variable makes this unstable (constant
+    # propagation fails). Take care when editing.
+ accs = map( + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) ) + return SimpleVarInfo(vals, accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) @@ -278,21 +285,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) -getlogp(vi::SimpleVarInfo) = vi.logp -getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] - -setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp - -function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::SimpleVarInfo) = vi.accs +setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ keys(vi::SimpleVarInfo) @@ -302,12 +296,12 @@ Return an iterator of keys present in `vi`. Base.keys(vi::SimpleVarInfo) = keys(vi.values) Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) +function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") end function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) @@ -454,11 +448,11 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - logp = getlogp(varinfo_right) + accs = deepcopy(getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) - return SimpleVarInfo(values, logp, 
transformation) + return SimpleVarInfo(values, accs, transformation) end # Context implementations @@ -473,9 +467,11 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = to_maybe_linked_internal(vi, vn, dist, value) + f = to_maybe_linked_internal_transform(vi, vn, dist) + value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi + vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + return value, vi end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. @@ -497,8 +493,8 @@ islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - isempty(vi) && return T[] +function values_as(vi::SimpleVarInfo, ::Type{Vector}) + isempty(vi) && return Any[] return mapreduce(tovec, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} @@ -613,12 +609,11 @@ function link!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) x = vi.values y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) + vi_new = Accessors.@set(vi.values = y) + vi_new = acclogprior!!(vi_new, -logjac) return settrans!!(vi_new, t) end @@ -627,12 +622,11 @@ function invlink!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. 
    b = t.bijector
    y = vi.values
    x, logjac = with_logabsdet_jacobian(b, y)
-    lp_new = getlogp(vi) + logjac
-    vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new)
+    vi_new = Accessors.@set(vi.values = x)
+    vi_new = acclogprior!!(vi_new, logjac)
    return settrans!!(vi_new, NoTransformation())
end

@@ -645,15 +639,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
    return invlink_transform(dist)
end

-# Threadsafe stuff.
-# For `SimpleVarInfo` we don't really need `Ref` so let's not use it.
-function ThreadSafeVarInfo(vi::SimpleVarInfo)
-    return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads() * 2))
-end
-function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref})
-    return ThreadSafeVarInfo(
-        vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
-    )
-end
-
has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector
diff --git a/src/submodel.jl b/src/submodel.jl
new file mode 100644
index 000000000..dcb107bb4
--- /dev/null
+++ b/src/submodel.jl
@@ -0,0 +1,195 @@
+"""
+    Submodel{M,AutoPrefix}
+
+A wrapper around a model, plus a flag indicating whether it should be automatically
+prefixed with the left-hand variable in a `~` statement.
+"""
+struct Submodel{M,AutoPrefix}
+    model::M
+end
+
+"""
+    to_submodel(model::Model[, auto_prefix::Bool])
+
+Return a model wrapper indicating that it is a sampleable model over the return-values.
+
+This is mainly meant to be used on the right-hand side of a `~` operator to indicate that
+the model can be sampled from but not necessarily evaluated for its log density.
+
+!!! warning
+    Note that some other operations that one typically associates with expressions of the
+    form `left ~ right`, such as [`condition`](@ref), will also not work with `to_submodel`.
+
+!!! warning
+    To avoid variable names clashing between models, it is recommended to leave the argument `auto_prefix` equal to `true`.
+ If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly, i.e. `to_submodel(prefix(model, @varname(my_prefix)))` + +# Arguments +- `model::Model`: the model to wrap. +- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand + side of the `~` statement. Default: `true`. + +# Examples + +## Simple example +```jldoctest submodel-to_submodel; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y) + a ~ to_submodel(demo1(x)) + return y ~ Uniform(0, a) + end; +``` + +When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +```jldoctest submodel-to_submodel +julia> vi = VarInfo(demo2(missing, 0.4)); + +julia> @varname(a.x) in keys(vi) +true +``` + +The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, +and can be used in subsequent lines of the model, as shown above. +```jldoctest submodel-to_submodel +julia> @varname(a) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel +julia> x = vi[@varname(a.x)]; + +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +true +``` + +## Without automatic prefixing +As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically +prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel +will not be prefixed. 
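As an aside, the name-clash problem that prefixing guards against can be sketched without DynamicPPL at all. The following is an illustrative stand-in only: the `record!` helper and the `Dict`-based store are hypothetical, mimicking how a `VarInfo` is keyed by variable name.

```julia
# Illustrative sketch (not DynamicPPL code): a VarInfo-like store keyed by
# variable name. Re-using an unprefixed submodel that defines `x` collides;
# per-use prefixes (what `to_submodel`'s auto-prefixing does) keep uses apart.
function record!(store::Dict{String,Float64}, name::String, value::Float64)
    haskey(store, name) && error("variable name clash: $name")
    store[name] = value
    return store
end

store = Dict{String,Float64}()
record!(store, "sub1.x", 0.1)      # first submodel use, prefixed
record!(store, "sub2.x", 0.2)      # second use, different prefix: no clash
# record!(store, "sub1.x", 0.3)    # would throw: variable name clash
```

Automatic prefixing simply derives such prefixes from the left-hand side of `~`, so each use of a submodel writes to distinct keys.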
+```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions)
+julia> @model function demo1(x)
+           x ~ Normal()
+           return 1 + abs(x)
+       end;
+
+julia> @model function demo2_no_prefix(x, z)
+           a ~ to_submodel(demo1(x), false)
+           return z ~ Uniform(-a, 1)
+       end;
+
+julia> vi = VarInfo(demo2_no_prefix(missing, 0.4));
+
+julia> @varname(x) in keys(vi)  # here we just use `x` instead of `a.x`
+true
+```
+
+However, not using prefixing is generally not recommended, as it can lead to variable name
+clashes unless one is careful. For example, re-using the same submodel twice within a model
+without prefixing will lead to a variable name clash. In that case, one can manually prefix
+using [`prefix(::Model, input)`](@ref):
+
+```jldoctest submodel-to_submodel-prefix
+julia> @model function demo2(x, y, z)
+           a ~ to_submodel(prefix(demo1(x), :sub1), false)
+           b ~ to_submodel(prefix(demo1(y), :sub2), false)
+           return z ~ Uniform(-a, b)
+       end;
+
+julia> vi = VarInfo(demo2(missing, missing, 0.4));
+
+julia> @varname(sub1.x) in keys(vi)
+true
+
+julia> @varname(sub2.x) in keys(vi)
+true
+```
+
+Variables `a` and `b` are not tracked, but are assigned the return values of the respective
+calls to `demo1`:
+
+```jldoctest submodel-to_submodel-prefix
+julia> @varname(a) in keys(vi)
+false
+
+julia> @varname(b) in keys(vi)
+false
+```
+
+We can check that the log joint probability of the model accumulated in `vi` is correct:
+
+```jldoctest submodel-to_submodel-prefix
+julia> sub1_x = vi[@varname(sub1.x)];
+
+julia> sub2_x = vi[@varname(sub2.x)];
+
+julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
+
+julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
+
+julia> getlogjoint(vi) ≈ logprior + loglikelihood
+true
+```
+
+## Usage as likelihood is illegal
+
+Note that it is illegal to use a `to_submodel` model as a likelihood in another model:
+
+```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions)
+julia> @model inner() = x ~
Normal() +inner (generic function with 2 methods) + +julia> @model illegal_likelihood() = a ~ to_submodel(inner()) +illegal_likelihood (generic function with 2 methods) + +julia> model = illegal_likelihood() | (a = 1.0,); + +julia> model() +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed +[...] +``` +""" +to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) + +# When automatic prefixing is used, the submodel itself doesn't carry the +# prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel +# is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then +# passed into this function. +# +# `parent_context` here refers to the context of the model that contains the +# submodel. +function _evaluate!!( + submodel::Submodel{M,AutoPrefix}, + vi::AbstractVarInfo, + parent_context::AbstractContext, + left_vn::VarName, +) where {M<:Model,AutoPrefix} + # First, we construct the context to be used when evaluating the submodel. There + # are several considerations here: + # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but + # _only_ if automatic prefixing is supposed to be applied. + submodel_context_prefixed = if AutoPrefix + PrefixContext(left_vn, submodel.model.context) + else + submodel.model.context + end + + # (2) We need to respect the leaf-context of the parent model. This, unfortunately, + # means disregarding the leaf-context of the submodel. + submodel_context = setleafcontext( + submodel_context_prefixed, leafcontext(parent_context) + ) + + # (3) We need to use the parent model's context to wrap the whole thing, so that + # e.g. if the user conditions the parent model, the conditioned variables will be + # correctly picked up when evaluating the submodel. + eval_context = setleafcontext(parent_context, submodel_context) + + # (4) Finally, we need to store that context inside the submodel. 
+ model = contextualize(submodel.model, eval_context) + + # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This + # returns a tuple of submodel.model's return value and the new varinfo. + return _evaluate!!(model, vi) +end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl deleted file mode 100644 index 5f1ec95ec..000000000 --- a/src/submodel_macro.jl +++ /dev/null @@ -1,290 +0,0 @@ -""" - @submodel model - @submodel ... = model - -Run a Turing `model` nested inside of a Turing model. - -!!! warning - This is deprecated and will be removed in a future release. - Use `left ~ to_submodel(model)` instead (see [`to_submodel`](@ref)). - -# Examples - -```jldoctest submodel; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - @submodel a = demo1(x) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel -julia> vi = VarInfo(demo2(missing, 0.4)); -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 - -julia> @varname(x) in keys(vi) -true -``` - -Variable `a` is not tracked since it can be computed from the random variable `x` that was -tracked when running `demo1`: -```jldoctest submodel -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel -julia> x = vi[@varname(x)]; - -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` -""" -macro submodel(expr) - return submodel(:(prefix = false), expr) -end - -""" - @submodel prefix=... model - @submodel prefix=... ... = model - -Run a Turing `model` nested inside of a Turing model and add "`prefix`." as a prefix -to all random variables inside of the `model`. 
- -Valid expressions for `prefix=...` are: -- `prefix=false`: no prefix is used. -- `prefix=true`: _attempt_ to automatically determine the prefix from the left-hand side - `... = model` by first converting into a `VarName`, and then calling `Symbol` on this. -- `prefix=expression`: results in the prefix `Symbol(expression)`. - -The prefix makes it possible to run the same Turing model multiple times while -keeping track of all random variables correctly. - -!!! warning - This is deprecated and will be removed in a future release. - Use `left ~ to_submodel(model)` instead (see [`to_submodel(model)`](@ref)). - -# Examples -## Example models -```jldoctest submodelprefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y, z) - @submodel prefix="sub1" a = demo1(x) - @submodel prefix="sub2" b = demo1(y) - return z ~ Uniform(-a, b) - end; -``` - -When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and -`sub2.x` will be sampled: -```jldoctest submodelprefix -julia> vi = VarInfo(demo2(missing, missing, 0.4)); -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. 
-│ caller = ip:0x0 -└ @ Core :-1 - -julia> @varname(sub1.x) in keys(vi) -true - -julia> @varname(sub2.x) in keys(vi) -true -``` - -Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and -`sub2.x` that were tracked when running `demo1`: -```jldoctest submodelprefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodelprefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogp(vi) ≈ logprior + loglikelihood -true -``` - -## Different ways of setting the prefix -```jldoctest submodel-prefix-alternatives; setup=:(using DynamicPPL, Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> # When `prefix` is unspecified, no prefix is used. - @model submodel_noprefix() = @submodel a = inner() -submodel_noprefix (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_noprefix())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Explicitely don't use any prefix. - @model submodel_prefix_false() = @submodel prefix=false a = inner() -submodel_prefix_false (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Automatically determined from `a`. 
- @model submodel_prefix_true() = @submodel prefix=true a = inner() -submodel_prefix_true (generic function with 2 methods) - -julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Using a static string. - @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() -submodel_prefix_string (generic function with 2 methods) - -julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Using string interpolation. - @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() -submodel_prefix_interpolation (generic function with 2 methods) - -julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Or using some arbitrary expression. - @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() -submodel_prefix_expr (generic function with 2 methods) - -julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # (×) Automatic prefixing without a left-hand side expression does not work! - @model submodel_prefix_error() = @submodel prefix=true inner() -ERROR: LoadError: cannot automatically prefix with no left-hand side -[...] -``` - -# Notes -- The choice `prefix=expression` means that the prefixing will incur a runtime cost. 
- This is also the case for `prefix=true`, depending on whether the expression on the - the right-hand side of `... = model` requires runtime-information or not, e.g. - `x = model` will result in the _static_ prefix `x`, while `x[i] = model` will be - resolved at runtime. -""" -macro submodel(prefix_expr, expr) - return submodel(prefix_expr, expr, esc(:__context__)) -end - -# Automatic prefixing. -function prefix_submodel_context(prefix::Bool, left::Symbol, ctx) - return prefix ? prefix_submodel_context(left, ctx) : ctx -end - -function prefix_submodel_context(prefix::Bool, left::Expr, ctx) - return prefix ? prefix_submodel_context(varname(left), ctx) : ctx -end - -# Manual prefixing. -prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) -function prefix_submodel_context(prefix, ctx) - # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) -end - -function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) - # E.g. `prefix="asd"`. - return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) -end - -function prefix_submodel_context(prefix::Bool, ctx) - if prefix - error("cannot automatically prefix with no left-hand side") - end - - return ctx -end - -const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax." - -function submodel(prefix_expr, expr, ctx=esc(:__context__)) - prefix_left, prefix = getargs_assignment(prefix_expr) - if prefix_left !== :prefix - error("$(prefix_left) is not a valid kwarg") - end - - # The user expects `@submodel ...` to return the - # return-value of the `...`, hence we need to capture - # the return-value and handle it correctly. - @gensym retval - - # `prefix=false` => don't prefix, i.e. do nothing to `ctx`. - # `prefix=true` => automatically determine prefix. - # `prefix=...` => use it. 
- args_assign = getargs_assignment(expr) - return if args_assign === nothing - ctx = prefix_submodel_context(prefix, ctx) - quote - # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. - $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) - - $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( - $(esc(expr)), $(esc(:__varinfo__)), $(ctx) - ) - $retval - end - else - L, R = args_assign - # Now that we have `L` and `R`, we can prefix automagically. - try - ctx = prefix_submodel_context(prefix, L, ctx) - catch e - error( - "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", - ) - end - quote - # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. - $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) - - $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( - $(esc(R)), $(esc(:__varinfo__)), $(ctx) - ) - $(esc(L)) = $retval - end - end -end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 0c267c1c5..155f3b68d 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,28 +4,57 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, - LogDensityFunction, - VarInfo, - AbstractVarInfo, - link, - DefaultContext, - AbstractContext +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link using LogDensityProblems: logdensity, logdensity_and_gradient -using Random: Random, Xoshiro +using Random: AbstractRNG, default_rng using Statistics: median using Test: @test -export ADResult, run_ad, ADIncorrectException +export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest """ - REFERENCE_ADTYPE + AbstractADCorrectnessTestSetting -Reference AD backend to use for comparison. 
In this case, ForwardDiff.jl, since -it's the default AD backend used in Turing.jl. +Different ways of testing the correctness of an AD backend. """ -const REFERENCE_ADTYPE = AutoForwardDiff() +abstract type AbstractADCorrectnessTestSetting end + +""" + WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against the result obtained with `adtype`. + +`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in +Turing.jl. +""" +struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting + adtype::AD +end +WithBackend() = WithBackend(AutoForwardDiff()) + +""" + WithExpectedResult( + value::T, + grad::AbstractVector{T} + ) where {T <: AbstractFloat} + <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against a known result (e.g. one obtained +analytically, or one obtained with a different backend previously). Both the +value of the primal (i.e. the log-density) as well as its gradient must be +supplied. +""" +struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting + value::T + grad::AbstractVector{T} +end + +""" + NoTest() <: AbstractADCorrectnessTestSetting + +Disable correctness testing. +""" +struct NoTest <: AbstractADCorrectnessTestSetting end """ ADIncorrectException{T<:AbstractFloat} @@ -45,39 +74,38 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception end """ - ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} + ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} Data structure to store the results of the AD correctness test. The type parameter `Tparams` is the numeric type of the parameters passed in; -`Tresult` is the type of the value and the gradient. +`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the +absolute and relative tolerances used for correctness testing. 
 # Fields
 $(TYPEDFIELDS)
 """
-struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
    "The VarInfo that was used"
     varinfo::AbstractVarInfo
-    "The evaluation context that was used"
-    context::AbstractContext
     "The values at which the model was evaluated"
     params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
-    "The absolute tolerance for the value of logp"
-    value_atol::Tresult
-    "The absolute tolerance for the gradient of logp"
-    grad_atol::Tresult
+    "Absolute tolerance used for correctness test"
+    atol::Ttol
+    "Relative tolerance used for correctness test"
+    rtol::Ttol
     "The expected value of logp"
     value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
     grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Tresult}
+    value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Tresult}}
+    grad_actual::Vector{Tresult}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
     time_vs_primal::Union{Nothing,Tresult}
 end

@@ -86,15 +114,12 @@ end
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
-        value_atol=1e-6,
-        grad_atol=1e-6,
+        atol::AbstractFloat=100 * eps(),
+        rtol::AbstractFloat=sqrt(eps()),
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        context::AbstractContext=DefaultContext(),
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult

@@ -136,8 +161,8 @@ Everything else is optional, and can be categorised into several groups:

   Note that if the VarInfo is not specified (and thus automatically generated)
   the parameters in it will have been sampled from the prior of the model. If
-  you want to seed the parameter generation, the easiest way is to pass a
-  `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+  you want to seed the parameter generation for the VarInfo, you can pass the
+  `rng` keyword argument, which will then be used to create the VarInfo.

   Finally, note that these only reflect the parameters used for _evaluating_
   the gradient. If you also want to control the parameters used for
@@ -146,33 +171,37 @@ Everything else is optional, and can be categorised into several groups:
   prep_params)`. You could then evaluate the gradient at a different set of
   parameters using the `params` keyword argument.

-3. _How to specify the evaluation context._
-
-   A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
-   argument to control the evaluation context. This defaults to
-   `DefaultContext()`.
-
-4. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._

   Once logp and its gradient has been calculated with the specified `adtype`,
-   it must be tested for correctness.
+   it can optionally be tested for correctness. The exact way this is tested
+   is specified in the `test` parameter.

-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   There are several options for this:

-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+     - You can explicitly specify the correct value using
+       [`WithExpectedResult()`](@ref).
+     - You can compare against the result obtained with a different AD backend
+       using [`WithBackend(adtype)`](@ref).
+     - You can disable testing by passing [`NoTest()`](@ref).
+     - The default is to compare against the result obtained with ForwardDiff,
+       i.e. `WithBackend(AutoForwardDiff())`.
+     - `test=false` and `test=true` are synonyms for
+       `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.

-5. _How to specify the tolerances._ (Only if `test=true`.)
+4. _How to specify the tolerances._ (Only if testing is enabled.)

-   The tolerances for the value and gradient can be set using `value_atol` and
-   `grad_atol`. These default to 1e-6.
+   Both absolute and relative tolerances can be specified using the `atol` and
+   `rtol` keyword arguments respectively. The behaviour of these is similar to
+   `isapprox()`, i.e. the value and gradient are considered correct if either
+   atol or rtol is satisfied. The default values are `100*eps()` for `atol` and
+   `sqrt(eps())` for `rtol`.

-6. _Whether to output extra logging information._
+   For the most part, it is the `rtol` check that is more meaningful, because
+   we cannot know the magnitude of logp and its gradient a priori. The `atol`
+   value is supplied to handle the case where gradients are equal to zero.
+
+5. _Whether to output extra logging information._

   By default, this function prints messages when it runs. To silence it, set
   `verbose=false`.

@@ -189,49 +218,58 @@ thrown as-is.
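Taken together, the settings documented above compose as in the following sketch (the model `m` is hypothetical; everything else is the `run_ad` API as described in this docstring):

```julia
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, WithExpectedResult, NoTest
using ADTypes: AutoForwardDiff, AutoReverseDiff

# Default behaviour: check ReverseDiff against a ForwardDiff-computed ground truth.
result = run_ad(m, AutoReverseDiff(); test=WithBackend())

# Reuse the first result as the known-correct answer for another backend,
# avoiding a second ForwardDiff evaluation of the ground truth.
run_ad(
    m,
    AutoReverseDiff(; compile=true);
    test=WithExpectedResult(result.value_actual, result.grad_actual),
)

# Benchmark only, with no correctness check.
run_ad(m, AutoForwardDiff(); test=NoTest(), benchmark=true)
```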
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
-    value_atol::AbstractFloat=1e-6,
-    grad_atol::AbstractFloat=1e-6,
-    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    atol::AbstractFloat=100 * eps(),
+    rtol::AbstractFloat=sqrt(eps()),
+    rng::AbstractRNG=default_rng(),
+    varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    context::AbstractContext=DefaultContext(),
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
+    # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params)  # Concretise

+    # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println("       params : $(params)")
-    ldf = LogDensityFunction(model, varinfo, context; adtype=adtype)
-
+    ldf = LogDensityFunction(model, varinfo; adtype=adtype)
     value, grad = logdensity_and_gradient(ldf, params)
+    # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
     verbose && println("       actual : $((value, grad))")

-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
         end
+        # Perform testing
         verbose && println("     expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)
-        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-        isapprox(value, value_true; atol=value_atol) || exc()
-        isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
+        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
+        isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+        isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
     end

+    # Benchmark
     time_vs_primal = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
@@ -245,11 +283,10 @@ function run_ad(
     return ADResult(
         model,
         varinfo,
-        context,
         params,
         adtype,
-        value_atol,
-        grad_atol,
+        atol,
+        rtol,
         value_true,
         grad_true,
         value,
diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl
index 7404a9af7..863db4262 100644
--- a/src/test_utils/contexts.jl
+++ b/src/test_utils/contexts.jl
@@ -3,34 +3,6 @@
 #
 # Utilities for testing contexts.

-"""
-Context that multiplies each log-prior by mod
-used to test whether varwise_logpriors respects child-context.
-"""
-struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
-    mod::T
-    context::Ctx
-end
-function TestLogModifyingChildContext(
-    mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext()
-)
-    return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
-end
-
-DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
-DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
-function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
-    return TestLogModifyingChildContext(context.mod, child)
-end
-function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
-    value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
-    return value, logp * context.mod, vi
-end
-function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
-    logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
-    return logp * context.mod, vi
-end
-
 # Dummy context to test nested behaviors.
 struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
     context::C
@@ -61,7 +33,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
     # To see change, let's make sure we're using a different leaf context than the current.
     leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
-        PriorContext()
+        DynamicPPL.DynamicTransformationContext{false}()
     else
         DefaultContext()
     end
@@ -91,10 +63,12 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
     # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
     # Untyped varinfo.
     varinfo_untyped = DynamicPPL.VarInfo()
-    @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
-    @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
+    model_with_spl = contextualize(model, SamplingContext(context))
+    model_without_spl = contextualize(model, context)
+    @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any
+    @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any
     # Typed varinfo.
     varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped)
-    @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
-    @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
+    @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any
+    @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any
 end
diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl
index ce79f2302..93aed074c 100644
--- a/src/test_utils/model_interface.jl
+++ b/src/test_utils/model_interface.jl
@@ -93,7 +93,7 @@ a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
 """
 function varnames(model::Model)
     return collect(
-        keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext())))
+        keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict()))))
     )
 end
diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl
index c44024863..8ffb7cbdf 100644
--- a/src/test_utils/models.jl
+++ b/src/test_utils/models.jl
@@ -148,7 +148,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf
     1.5 ~ Normal(m, sqrt(s))
     2.0 ~ Normal(m, sqrt(s))

-    return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s, m, x=[1.5, 2.0])
 end

 function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)})
@@ -194,7 +194,7 @@ end
     m ~ product_distribution(Normal.(0, sqrt.(s)))
     x ~ MvNormal(m, Diagonal(s))

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -225,7 +225,7 @@ end
     end
     x ~ MvNormal(m, Diagonal(s))

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -248,7 +248,7 @@ end
     m ~ MvNormal(zero(x), Diagonal(s))
     x ~ MvNormal(m, Diagonal(s))

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m)
     s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
@@ -279,7 +279,7 @@ end
         x[i] ~ Normal(m[i], sqrt(s[i]))
     end

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -304,7 +304,7 @@ end
     m ~ Normal(0, sqrt(s))
     x .~ Normal(m, sqrt(s))

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m)
     return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
@@ -327,7 +327,7 @@ end
     m ~ MvNormal(zeros(2), Diagonal(s))
     [1.5, 2.0] ~ MvNormal(m, Diagonal(s))

-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
     s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
@@ -358,7 +358,7 @@ end
     1.5 ~ Normal(m[1], sqrt(s[1]))
     2.0 ~ Normal(m[2], sqrt(s[2]))

-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -384,7 +384,7 @@ end
     1.5 ~ Normal(m, sqrt(s))
     2.0 ~ Normal(m, sqrt(s))

-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
     return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
@@ -407,7 +407,7 @@ end
     m ~ Normal(0, sqrt(s))
     [1.5, 2.0] .~ Normal(m, sqrt(s))

-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
     return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
@@ -440,7 +440,7 @@ end
     1.5 ~ Normal(m[1], sqrt(s[1]))
     2.0 ~ Normal(m[2], sqrt(s[2]))

-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(
     model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m
@@ -476,9 +476,9 @@ end
     # Submodel likelihood
     # With to_submodel, we have to have a left-hand side variable to
     # capture the result, so we just use a dummy variable
-    _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x))
+    _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x), false)

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -505,7 +505,7 @@ end
     x[:, 1] ~ MvNormal(m, Diagonal(s))

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -535,7 +535,7 @@ end
     x[:, 1] ~ MvNormal(m, Diagonal(s_vec))

-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m)
     n = length(model.args.x)
diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl
index 539872143..542fc17fc 100644
--- a/src/test_utils/varinfo.jl
+++ b/src/test_utils/varinfo.jl
@@ -37,12 +37,6 @@ function setup_varinfos(
     svi_untyped = SimpleVarInfo(OrderedDict())
     svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())

-    # SimpleVarInfo{<:Any,<:Ref}
-    svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed)))
-    svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))
-    svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv)))
-
-    lp = getlogp(vi_typed_metadata)
     varinfos = map((
         vi_untyped_metadata,
         vi_untyped_vnv,
@@ -51,12 +45,10 @@
         svi_typed,
         svi_untyped,
         svi_vnv,
-        svi_typed_ref,
-        svi_untyped_ref,
-        svi_vnv_ref,
     )) do vi
-        # Set them all to the same values.
-        DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
+        # Set them all to the same values and evaluate logp.
+        vi = update_values!!(vi, example_values, varnames)
+        last(DynamicPPL.evaluate!!(model, vi))
     end

     if include_threadsafe
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index bd1876a19..51c57651d 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -2,11 +2,11 @@
     ThreadSafeVarInfo

 A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an
-array of log probabilities for thread-safe execution of a probabilistic model.
+array of accumulators for thread-safe execution of a probabilistic model.
 """
-struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
+struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo
     varinfo::V
-    logps::L
+    accs_by_thread::Vector{L}
 end

 function ThreadSafeVarInfo(vi::AbstractVarInfo)
     # In ThreadSafeVarInfo we use threadid() to index into the array of logp
@@ -18,64 +18,72 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo)
     # but Mooncake can't differentiate through that. Empirically, nthreads()*2
     # seems to provide an upper bound to maxthreadid(), so we use that here.
     # See https://github.com/TuringLang/DynamicPPL.jl/pull/936
-    return ThreadSafeVarInfo(
-        vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
-    )
+    accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)]
+    return ThreadSafeVarInfo(vi, accs_by_thread)
 end

 ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi

-const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{
-    V,<:AbstractArray{<:Ref}
-}
-
 transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo)

-# Instead of updating the log probability of the underlying variables we
-# just update the array of log probabilities.
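The `map(split, getaccs(vi))` call in the `ThreadSafeVarInfo` constructor relies on each accumulator supporting a `split`/`combine` pair: `split` yields an emptied copy for a thread to fill during evaluation, and `combine` merges the thread's contribution back into the main accumulator. A minimal sketch of that contract, assuming `LogPriorAccumulator`'s single-argument constructor holds the running log prior:

```julia
using DynamicPPL: LogPriorAccumulator, split, combine

# Main accumulator, already holding some log-prior mass.
acc = LogPriorAccumulator(-3.2)

# `split` hands each worker thread a fresh, zeroed accumulator...
tacc = split(acc)

# ...which accumulates that thread's own terms during model evaluation and is
# then merged back into the main accumulator. Combining an untouched split
# leaves the main accumulator's value unchanged.
combine(acc, tacc)
```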
-function acclogp!!(vi::ThreadSafeVarInfo, logp)
-    vi.logps[Threads.threadid()] += logp
-    return vi
+# Set the accumulator in question in vi.varinfo, and set the thread-specific
+# accumulators of the same type to be empty.
+function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator)
+    inner_vi = setacc!!(vi.varinfo, acc)
+    new_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread)
+    return ThreadSafeVarInfo(inner_vi, new_accs_by_thread)
 end
-function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp)
-    vi.logps[Threads.threadid()][] += logp
-    return vi
+
+# Get both the main accumulator and the thread-specific accumulators of the same type and
+# combine them.
+function getacc(vi::ThreadSafeVarInfo, accname::Val)
+    main_acc = getacc(vi.varinfo, accname)
+    other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread)
+    return foldl(combine, other_accs; init=main_acc)
 end

-# The current log probability of the variables has to be computed from
-# both the wrapped variables and the thread-specific log probabilities.
-getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
-getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps)
+hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname)
+acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo)

-# TODO: Make remaining methods thread-safe.
-function resetlogp!!(vi::ThreadSafeVarInfo)
-    return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps))
+function getaccs(vi::ThreadSafeVarInfo)
+    # This method is a bit finicky to maintain type stability. For instance, moving the
+    # accname -> Val(accname) part in the main `map` call makes constant propagation fail
+    # and this becomes unstable. Do check the effects if you make edits.
+    accnames = acckeys(vi)
+    accname_vals = map(Val, accnames)
+    return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals))
 end
-function resetlogp!!(vi::ThreadSafeVarInfoWithRef)
-    for x in vi.logps
-        x[] = zero(x[])
-    end
-    return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps)
-end
-function setlogp!!(vi::ThreadSafeVarInfo, logp)
-    return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps))
+
+# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that
+# should _not_ be thread-specific a specific method has to be written.
+function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val)
+    tid = Threads.threadid()
+    vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname)
+    return vi
 end
-function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp)
-    for x in vi.logps
-        x[] = zero(x[])
-    end
-    return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps)
+
+function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo)
+    tid = Threads.threadid()
+    vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid])
+    return vi
 end

-has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)
+has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)

 function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution)
     return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist)
 end

+# TODO(mhauru) Why these short-circuits? Why not use the thread-specific ones?
 get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo)
-increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo)
-reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo)
-set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n)
+function increment_num_produce!!(vi::ThreadSafeVarInfo)
+    return ThreadSafeVarInfo(increment_num_produce!!(vi.varinfo), vi.accs_by_thread)
+end
+function reset_num_produce!!(vi::ThreadSafeVarInfo)
+    return ThreadSafeVarInfo(reset_num_produce!!(vi.varinfo), vi.accs_by_thread)
+end
+function set_num_produce!!(vi::ThreadSafeVarInfo, n::Int)
+    return ThreadSafeVarInfo(set_num_produce!!(vi.varinfo, n), vi.accs_by_thread)
+end

 syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)

@@ -105,17 +113,20 @@ end
 # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
 # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
-# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
-# to define `getlogp(vi)`.
+# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates
+# to define `getacc(vi)`.
 function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
-    return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
+    model = contextualize(
+        model, setleafcontext(model.context, DynamicTransformationContext{false}())
+    )
+    return settrans!!(last(evaluate!!(model, vi)), t)
 end

 function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
-    return settrans!!(
-        last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
-        NoTransformation(),
+    model = contextualize(
+        model, setleafcontext(model.context, DynamicTransformationContext{true}())
     )
+    return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
 end

 function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
@@ -141,9 +152,9 @@ end
 function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
     # Defer to the wrapped `AbstractVarInfo` object.
-    # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the
-    # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in
-    # the `getlogp(vi)`.
+    # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the
+    # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in
+    # the `getlogprior(vi)`.
     return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model)
 end

@@ -180,6 +191,23 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo)
     return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo)))
 end

+function resetlogp!!(vi::ThreadSafeVarInfo)
+    vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo)
+    for i in eachindex(vi.accs_by_thread)
+        if hasacc(vi, Val(:LogPrior))
+            vi.accs_by_thread[i] = map_accumulator(
+                zero, vi.accs_by_thread[i], Val(:LogPrior)
+            )
+        end
+        if hasacc(vi, Val(:LogLikelihood))
+            vi.accs_by_thread[i] = map_accumulator(
+                zero, vi.accs_by_thread[i], Val(:LogLikelihood)
+            )
+        end
+    end
+    return vi
+end
+
 values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
 values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)
diff --git a/src/transforming.jl b/src/transforming.jl
index 429562ec8..e3da0ff29 100644
--- a/src/transforming.jl
+++ b/src/transforming.jl
@@ -27,18 +27,48 @@ function tilde_assume(
     # Only transform if `!isinverse` since `vi[vn, right]`
     # already performs the inverse transformation if it's transformed.
     r_transformed = isinverse ? r : link_transform(right)(r)
-    return r, lp, setindex!!(vi, r_transformed, vn)
+    if hasacc(vi, Val(:LogPrior))
+        vi = acclogprior!!(vi, lp)
+    end
+    return r, setindex!!(vi, r_transformed, vn)
+end
+
+function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi)
+    return tilde_observe!!(DefaultContext(), right, left, vn, vi)
 end

 function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
-    return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
+    return _transform!!(t, DynamicTransformationContext{false}(), vi, model)
 end

 function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
-    return settrans!!(
-        last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
-        NoTransformation(),
-    )
+    return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model)
+end
+
+function _transform!!(
+    t::AbstractTransformation,
+    ctx::DynamicTransformationContext,
+    vi::AbstractVarInfo,
+    model::Model,
+)
+    # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context:
+    model = contextualize(model, setleafcontext(model.context, ctx))
+    # but we do not need to use any accumulators other than LogPriorAccumulator
+    # (which is affected by the Jacobian of the transformation).
+    accs = getaccs(vi)
+    has_logprior = haskey(accs, Val(:LogPrior))
+    if has_logprior
+        old_logprior = getacc(accs, Val(:LogPrior))
+        vi = setaccs!!(vi, (old_logprior,))
+    end
+    vi = settrans!!(last(evaluate!!(model, vi)), t)
+    # Restore the accumulators.
+    if has_logprior
+        new_logprior = getacc(vi, Val(:LogPrior))
+        accs = setacc!!(accs, new_logprior)
+    end
+    vi = setaccs!!(vi, accs)
+    return vi
 end

 function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
diff --git a/src/utils.jl b/src/utils.jl
index 73a8b48b9..0f4d98b11 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -18,23 +18,29 @@ const LogProbType = float(Real)
 """
     @addlogprob!(ex)

-Add the result of the evaluation of `ex` to the joint log probability.
+Add a term to the log joint.

-# Examples
+If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the
+values are added to the log likelihood and log prior respectively.
+
+If `ex` evaluates to a number it is added to the log likelihood.

-This macro allows you to [include arbitrary terms in the likelihood](https://github.com/TuringLang/Turing.jl/issues/1332)
+# Examples

 ```jldoctest; setup = :(using Distributions)
-julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
+julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0);

 julia> @model function demo(x)
           μ ~ Normal()
-           @addlogprob! myloglikelihood(x, μ)
+           @addlogprob! mylogjoint(x, μ)
       end;

 julia> x = [1.3, -2.1];

-julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2)
+julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood
+true
+
+julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior
 true
 ```

@@ -44,7 +50,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328):
 julia> @model function demo(x)
           m ~ MvNormal(zero(x), I)
           if dot(m, x) < 0
-               @addlogprob! -Inf
+               @addlogprob! (; loglikelihood=-Inf)
               # Exit the model evaluation early
               return
           end
@@ -55,37 +61,22 @@ julia> @model function demo(x)
 julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf
 true
 ```
-
-!!! note
-    The `@addlogprob!` macro increases the accumulated log probability regardless of the evaluation context,
-    i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density.
-    If you would like to avoid this behaviour you should check the evaluation context.
-    It can be accessed with the internal variable `__context__`.
-    For instance, in the following example the log density is not accumulated when only the log prior is computed:
-    ```jldoctest; setup = :(using Distributions)
-    julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
-
-    julia> @model function demo(x)
-               μ ~ Normal()
-               if DynamicPPL.leafcontext(__context__) !== PriorContext()
-                   @addlogprob! myloglikelihood(x, μ)
-               end
-           end;
-
-    julia> x = [1.3, -2.1];
-
-    julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2)
-    true
-
-    julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2)
-    true
-    ```
 """
 macro addlogprob!(ex)
     return quote
-        $(esc(:(__varinfo__))) = acclogp!!(
-            $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex))
-        )
+        val = $(esc(ex))
+        vi = $(esc(:(__varinfo__)))
+        if val isa Number
+            if hasacc(vi, Val(:LogLikelihood))
+                $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val)
+            end
+        elseif val isa NamedTuple
+            $(esc(:(__varinfo__))) = acclogp!!(
+                $(esc(:(__varinfo__))), val; ignore_missing_accumulator=true
+            )
+        else
+            error("logp must be a Number or a NamedTuple.")
+        end
     end
 end
diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl
index d3bfd697a..4922ddbb0 100644
--- a/src/values_as_in_model.jl
+++ b/src/values_as_in_model.jl
@@ -1,16 +1,7 @@
-struct TrackedValue{T}
-    value::T
-end
-
-is_tracked_value(::TrackedValue) = true
-is_tracked_value(::Any) = false
-
-check_tilde_rhs(x::TrackedValue) = x
-
 """
-    ValuesAsInModelContext
+    ValuesAsInModelAccumulator <: AbstractAccumulator

-A context that is used by [`values_as_in_model`](@ref) to obtain values
+An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
 of the model parameters as they are in the model.

 This is particularly useful when working in unconstrained space, but one
@@ -19,79 +10,49 @@ wants to extract the realization of a model in a constrained space.

 # Fields
 $(TYPEDFIELDS)
 """
-struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
+struct ValuesAsInModelAccumulator <: AbstractAccumulator
     "values that are extracted from the model"
     values::OrderedDict
     "whether to extract variables on the LHS of :="
     include_colon_eq::Bool
-    "child context"
-    context::C
 end
-function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
-    return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
+function ValuesAsInModelAccumulator(include_colon_eq)
+    return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
 end

-NodeTrait(::ValuesAsInModelContext) = IsParent()
-childcontext(context::ValuesAsInModelContext) = context.context
-function setchildcontext(context::ValuesAsInModelContext, child)
-    return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
-end
+accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

-is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
-function is_extracting_values(context::AbstractContext)
-    return is_extracting_values(NodeTrait(context), context)
+function split(acc::ValuesAsInModelAccumulator)
+    return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
 end
-is_extracting_values(::IsParent, ::AbstractContext) = false
-is_extracting_values(::IsLeaf, ::AbstractContext) = false
-
-function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
-    return setindex!(context.values, copy(value), prefix(context, vn))
+function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
+    if acc1.include_colon_eq != acc2.include_colon_eq
+        msg = "Cannot combine accumulators with different include_colon_eq values."
+        throw(ArgumentError(msg))
+    end
+    return ValuesAsInModelAccumulator(
+        merge(acc1.values, acc2.values), acc1.include_colon_eq
+    )
 end

-function broadcast_push!(context::ValuesAsInModelContext, vns, values)
-    return push!.((context,), vns, values)
+function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
+    setindex!(acc.values, deepcopy(val), vn)
+    return acc
 end

-# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
-function broadcast_push!(
-    context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
-)
-    for (vn, col) in zip(vns, eachcol(values))
-        push!(context, vn, col)
-    end
+function is_extracting_values(vi::AbstractVarInfo)
+    return hasacc(vi, Val(:ValuesAsInModel)) &&
+        getacc(vi, Val(:ValuesAsInModel)).include_colon_eq
 end

-# `tilde_asssume`
-function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
-    if is_tracked_value(right)
-        value = right.value
-        logp = zero(getlogp(vi))
-    else
-        value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
-    end
-    # Save the value.
-    push!(context, vn, value)
-    # Save the value.
-    # Pass on.
-    return value, logp, vi
-end
-function tilde_assume(
-    rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
-)
-    if is_tracked_value(right)
-        value = right.value
-        logp = zero(getlogp(vi))
-    else
-        value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
-    end
-    # Save the value.
-    push!(context, vn, value)
-    # Pass on.
- return value, logp, vi +function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right) + return push!(acc, vn, val) end +accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc + """ - values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) + values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) Get the values of `varinfo` as they would be seen in the model. @@ -108,8 +69,6 @@ space at the cost of additional model evaluations. - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: base context to use for the extraction. Defaults - to `DynamicPPL.DefaultContext()`. # Examples @@ -163,13 +122,8 @@ julia> # Approach 2: Extract realizations using `values_as_in_model`. true ``` """ -function values_as_in_model( - model::Model, - include_colon_eq::Bool, - varinfo::AbstractVarInfo, - context::AbstractContext=DefaultContext(), -) - context = ValuesAsInModelContext(include_colon_eq, context) - evaluate!!(model, varinfo, context) - return context.values +function values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) + varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),)) + varinfo = last(evaluate!!(model, varinfo)) + return getacc(varinfo, Val(:ValuesAsInModel)).values end diff --git a/src/varinfo.jl b/src/varinfo.jl index bc59c67a6..b3380e7f9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,10 +69,9 @@ end ########### """ - struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end A light wrapper over some kind of metadata. 
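The accumulator-based `values_as_in_model` above can be sketched in use as follows. This is a hedged illustration, not part of the patch: `demo` is a hypothetical model, and it assumes DynamicPPL at this PR's version (with `link!!`, `VarInfo`, and `@varname` available) plus Distributions:

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ LogNormal()  # constrained to be positive
    y := x + 1       # LHS of := is captured only when include_colon_eq=true
    return nothing
end

model = demo()
vi = VarInfo(model)
vi = DynamicPPL.link!!(vi, model)  # store x internally in unconstrained (linked) space

# values_as_in_model swaps in a ValuesAsInModelAccumulator, re-evaluates the
# model, and returns an OrderedDict of constrained-space values keyed by VarName.
vals = DynamicPPL.values_as_in_model(model, true, vi)
vals[@varname(x)]  # the constrained (positive) value, even though vi is linked
```

Note that the new signature takes no `context` argument; the extraction machinery now lives entirely in the accumulator.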
@@ -98,17 +97,19 @@ Note that for NTVarInfo, it is the user's responsibility to ensure that each
 symbol is visited at least once during model evaluation, regardless of any
 stochastic branching.
 """
-struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo
+struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
     metadata::Tmeta
-    logp::Base.RefValue{Tlogp}
-    num_produce::Base.RefValue{Int}
+    accs::Accs
 end
 
-VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0))
+function VarInfo(meta=Metadata())
+    return VarInfo(meta, default_accumulators())
+end
+
 """
-    VarInfo([rng, ]model[, sampler, context])
+    VarInfo([rng, ]model[, sampler])
 
 Generate a `VarInfo` object for the given `model`, by evaluating it once using
-the given `rng`, `sampler`, and `context`.
+the given `rng` and `sampler`.
 
 !!! warning
 
@@ -121,28 +122,12 @@ the given `rng`, `sampler`, and `context`.
     instead.
 """
 function VarInfo(
-    rng::Random.AbstractRNG,
-    model::Model,
-    sampler::AbstractSampler=SampleFromPrior(),
-    context::AbstractContext=DefaultContext(),
+    rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior()
 )
-    return typed_varinfo(rng, model, sampler, context)
+    return typed_varinfo(rng, model, sampler)
 end
-function VarInfo(
-    model::Model,
-    sampler::AbstractSampler=SampleFromPrior(),
-    context::AbstractContext=DefaultContext(),
-)
-    # No rng
-    return VarInfo(Random.default_rng(), model, sampler, context)
-end
-function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext)
-    # No sampler
-    return VarInfo(rng, model, SampleFromPrior(), context)
-end
-function VarInfo(model::Model, context::AbstractContext)
-    # No sampler, no rng
-    return VarInfo(Random.default_rng(), model, SampleFromPrior(), context)
+function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior())
+    return VarInfo(Random.default_rng(), model, sampler)
 end
 
 const UntypedVectorVarInfo = VarInfo{<:VarNamedVector}
@@ -199,42 +184,23 @@ end
 ########################
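The slimmed-down constructor surface above can be sketched as follows. A hedged illustration, not part of the patch: `gdemo` is a hypothetical model, and the snippet assumes DynamicPPL at this PR's version plus Distributions:

```julia
using DynamicPPL, Distributions, Random

@model function gdemo()
    m ~ Normal()
    1.5 ~ Normal(m, 1)  # a literal observation
end

model = gdemo()
vi = VarInfo(model)                        # default rng and SampleFromPrior()
vi2 = VarInfo(Random.Xoshiro(468), model)  # explicit rng; no context argument anymore

# With accumulators replacing the old logp field, log densities are read via
# the accessor functions rather than getlogp:
lp = getlogjoint(vi)  # equals getlogprior(vi) + getloglikelihood(vi)
```

The design point is that everything the removed `context` argument used to configure (which log densities to track, sampling behaviour) is now carried either by the accumulators on the `VarInfo` or by the sampling entry point.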
""" - untyped_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `Metadata` as its metadata field. +Construct a VarInfo object for the given `model`, which has just a single +`Metadata` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - varinfo = VarInfo(Metadata()) - context = SamplingContext(rng, sampler, context) - return last(evaluate!!(model, varinfo, context)) + return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) end -function untyped_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_varinfo(Random.default_rng(), model, sampler, context) -end -function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return untyped_varinfo(rng, model, SampleFromPrior(), context) -end -function untyped_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_varinfo(model, SampleFromPrior(), context) +function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_varinfo(Random.default_rng(), model, sampler) end """ @@ -285,10 +251,8 @@ function typed_varinfo(vi::UntypedVarInfo) ), ) end - logp = getlogp(vi) - num_produce = 
get_num_produce(vi) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -299,135 +263,76 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler, context, metadata]) + typed_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
""" function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function typed_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return typed_varinfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(untyped_varinfo(rng, model, sampler)) end -function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return typed_varinfo(rng, model, SampleFromPrior(), context) -end -function typed_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_varinfo(model, SampleFromPrior(), context) +function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_varinfo(Random.default_rng(), model, sampler) end """ - untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `VarNamedVector` as its metadata field. +Return a VarInfo object for the given `model`, which has just a single +`VarNamedVector` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end -function untyped_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) + return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) end -function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext -) - # No sampler - return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) -end -function untyped_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_vector_varinfo(model, SampleFromPrior(), context) +function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_vector_varinfo(Random.default_rng(), model, sampler) end """ - typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) + typed_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a -NamedTuple of `VarNamedVector`s as its metadata field. +Return a VarInfo object for the given `model`, which has a NamedTuple of +`VarNamedVector`s as its metadata field. 
# Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, deepcopy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) end -function typed_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return typed_vector_varinfo(Random.default_rng(), model, sampler, context) -end -function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext -) - # No sampler - return typed_vector_varinfo(rng, model, SampleFromPrior(), context) -end -function typed_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_vector_varinfo(model, SampleFromPrior(), context) +function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return 
typed_vector_varinfo(Random.default_rng(), model, sampler)
 end
 
 """
@@ -441,13 +346,22 @@ vector_length(md::Metadata) = sum(length, md.ranges)
 
 function unflatten(vi::VarInfo, x::AbstractVector)
     md = unflatten_metadata(vi.metadata, x)
-    # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases
-    # where e.g. x is a type gradient of some AD backend.
-    return VarInfo(
-        md,
-        Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)),
-        Ref(get_num_produce(vi)),
+    # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
+    # a gradient type of some AD backend.
+    # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
+    # for ThreadSafeVarInfo. In that one, if the map produces e.g. a ForwardDiff.Dual, but
+    # the accumulators in the VarInfo are plain floats, we error since we can't change the
+    # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
+    # messes with cases like using Float32 for logprobs and Float64 for x. Also, this is
+    # just plain ugly and hacky.
+    # The below line is finicky for type stability. For instance, assigning the eltype to
+    # convert to into an intermediate variable makes this unstable (constant propagation
+    # fails). Take care when editing.
+ accs = map( + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), + deepcopy(getaccs(vi)), ) + return VarInfo(md, accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in @@ -529,7 +443,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) + return VarInfo(metadata, deepcopy(varinfo.accs)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -618,9 +532,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo( - metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right)) - ) + return VarInfo(metadata, deepcopy(varinfo_right.accs)) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -976,8 +888,8 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!!(vi) - reset_num_produce!(vi) + vi = resetlogp!!(vi) + vi = reset_num_produce!!(vi) return vi end @@ -1011,46 +923,8 @@ end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") -getlogp(vi::VarInfo) = vi.logp[] - -function setlogp!!(vi::VarInfo, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::VarInfo, logp) - vi.logp[] += logp - return vi -end - -""" - get_num_produce(vi::VarInfo) - -Return the `num_produce` of `vi`. -""" -get_num_produce(vi::VarInfo) = vi.num_produce[] - -""" - set_num_produce!(vi::VarInfo, n::Int) - -Set the `num_produce` field of `vi` to `n`. -""" -set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n - -""" - increment_num_produce!(vi::VarInfo) - -Add 1 to `num_produce` in `vi`. 
-""" -increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 - -""" - reset_num_produce!(vi::VarInfo) - -Reset the value of `num_produce` the log of the joint probability of the observed data -and parameters sampled in `vi` to 0. -""" -reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) @@ -1064,7 +938,7 @@ function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1072,7 +946,7 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1085,8 +959,7 @@ end function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - # Call `_link!` instead of `link!` to avoid deprecation warning. 
-    _link!(vi, vns)
+    vi = _link!!(vi, vns)
     return vi
 end
 
@@ -1101,27 +974,28 @@ function link!!(
     return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
 end
 
-function _link!(vi::UntypedVarInfo, vns)
+function _link!!(vi::UntypedVarInfo, vns)
     # TODO: Change to a lazy iterator over `vns`
     if ~istrans(vi, vns[1])
         for vn in vns
             f = internal_to_linked_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, true, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, true, vn)
        end
    else
        @warn("[DynamicPPL] attempt to link a linked vi")
    end
+    return vi
 end
 
-# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a
+# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a
 # NamedTuple that matches the structure of the NTVarInfo.
-function _link!(vi::NTVarInfo, vns::VarNameTuple)
-    return _link!(vi, group_varnames_by_symbol(vns))
+function _link!!(vi::NTVarInfo, vns::VarNameTuple)
+    return _link!!(vi, group_varnames_by_symbol(vns))
 end
-function _link!(vi::NTVarInfo, vns::NamedTuple)
-    return _link!(vi.metadata, vi, vns)
+function _link!!(vi::NTVarInfo, vns::NamedTuple)
+    return _link!!(vi.metadata, vi, vns)
 end
 
 """
@@ -1133,7 +1007,7 @@ function filter_subsumed(filter_vns, filtered_vns)
     return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns)
 end
 
-@generated function _link!(
+@generated function _link!!(
     ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
 ) where {metadata_names,vns_names}
     expr = Expr(:block)
@@ -1151,8 +1025,8 @@ end
         # Iterate over all `f_vns` and transform
         for vn in f_vns
             f = internal_to_linked_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, true, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, true, vn)
         end
     else
         @warn("[DynamicPPL] attempt to link a linked vi")
@@ -1161,6 +1035,7 @@ end
         end,
     )
    end
+    push!(expr.args, :(return vi))
    return expr
end

@@ -1168,8 +1043,7 @@ function 
invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
     vns = all_varnames_grouped_by_symbol(vi)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    # Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
     return vi
 end
 
@@ -1177,7 +1051,7 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model)
     vns = keys(vi)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
     return vi
 end
 
@@ -1190,8 +1064,7 @@ end
 function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    # Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
     return vi
 end
 
@@ -1214,29 +1087,30 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
     return maybe_invlink_before_eval!!(t, vi, model)
 end
 
-function _invlink!(vi::UntypedVarInfo, vns)
+function _invlink!!(vi::UntypedVarInfo, vns)
     if istrans(vi, vns[1])
         for vn in vns
             f = linked_internal_to_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, false, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, false, vn)
         end
     else
         @warn("[DynamicPPL] attempt to invlink an invlinked vi")
     end
+    return vi
 end
 
-# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a
+# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a
 # NamedTuple that matches the structure of the NTVarInfo. 
-function _invlink!(vi::NTVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) +function _invlink!!(vi::NTVarInfo, vns::VarNameTuple) + return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::NTVarInfo, vns::NamedTuple) - return _invlink!(vi.metadata, vi, vns) +function _invlink!!(vi::NTVarInfo, vns::NamedTuple) + return _invlink!!(vi.metadata, vi, vns) end -@generated function _invlink!( +@generated function _invlink!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1254,8 +1128,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1263,6 +1137,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1279,7 +1154,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
setval!(md, yvec, vn) - acclogp!!(vi, -logjac) + vi = acclogprior!!(vi, -logjac) return vi end @@ -1314,8 +1189,10 @@ end function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1326,8 +1203,10 @@ end function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _link_metadata!( @@ -1336,20 +1215,39 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if f in vns_names - push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + NamedTuple{$metadata_names}($mds), cumulative_logjac + end, + ) + return expr end function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = 
zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1367,7 +1265,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Vectorize value. yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. @@ -1392,7 +1290,8 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _link_metadata!!( @@ -1400,6 +1299,7 @@ function _link_metadata!!( ) vns = target_vns === nothing ? keys(metadata) : target_vns dists = extract_priors(model, varinfo) + cumulative_logjac = zero(LogProbType) for vn in vns # First transform from however the variable is stored in vnv to the model # representation. @@ -1412,11 +1312,11 @@ function _link_metadata!!( val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) # TODO(mhauru) We are calling a !! function but ignoring the return value. # Fix this when attending to issue #653. 
- acclogp!!(varinfo, -logjac1 - logjac2) + cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) settrans!(metadata, true, vn) end - return metadata + return metadata, cumulative_logjac end function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) @@ -1452,11 +1352,10 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1467,8 +1366,10 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _invlink_metadata!( @@ -1477,20 +1378,41 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if (f in vns_names) - push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _invlink_metadata!!( + model, varinfo, metadata.$f, vns.$f + ) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return 
:(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + (NamedTuple{$metadata_names}($mds), cumulative_logjac) + end, + ) + return expr end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1509,7 +1431,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1534,24 +1456,26 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns + cumulative_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) new_val, logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. 
- acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata + return metadata, cumulative_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not @@ -1708,19 +1632,35 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - vi_str = """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - | Varnames : $(string(vi.metadata.vns)) - | Range : $(vi.metadata.ranges) - | Vals : $(vi.metadata.vals) - | Orders : $(vi.metadata.orders) - | Logp : $(getlogp(vi)) - | #produce : $(get_num_produce(vi)) - | flags : $(vi.metadata.flags) - \\======================================================================= - """ + lines = Tuple{String,Any}[ + ("VarNames", vi.metadata.vns), + ("Range", vi.metadata.ranges), + ("Vals", vi.metadata.vals), + ("Orders", vi.metadata.orders), + ] + for accname in acckeys(vi) + push!(lines, (string(accname), getacc(vi, Val(accname)))) + end + push!(lines, ("flags", vi.metadata.flags)) + max_name_length = maximum(map(length ∘ first, lines)) + fmt = Printf.Format("%-$(max_name_length)s") + vi_str = ( + """ + /======================================================================= + | VarInfo + |----------------------------------------------------------------------- + """ * + prod( + map(lines) do (name, value) + """ + | $(Printf.format(fmt, name)) : $(value) + """ + end, + ) * + """ + \\======================================================================= + """ + ) return print(io, vi_str) end @@ -1750,7 +1690,11 @@ end function Base.show(io::IO, vi::UntypedVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi); 
digits=3))
+    print(io, "; accumulators: ")
+    # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed representation
+    # of vi anyway. However, technically `show(io, x)` should give full details of x and
+    # preferably output valid Julia code.
+    show(io, MIME"text/plain"(), getaccs(vi))
     return print(io, ")")
 end
diff --git a/test/accumulators.jl b/test/accumulators.jl
new file mode 100644
index 000000000..36bb95e46
--- /dev/null
+++ b/test/accumulators.jl
@@ -0,0 +1,176 @@
+module AccumulatorTests
+
+using Test
+using Distributions
+using DynamicPPL
+using DynamicPPL:
+    AccumulatorTuple,
+    LogLikelihoodAccumulator,
+    LogPriorAccumulator,
+    NumProduceAccumulator,
+    accumulate_assume!!,
+    accumulate_observe!!,
+    combine,
+    convert_eltype,
+    getacc,
+    increment,
+    map_accumulator,
+    setacc!!,
+    split
+
+@testset "accumulators" begin
+    @testset "individual accumulator types" begin
+        @testset "constructors" begin
+            @test LogPriorAccumulator(0.0) ==
+                LogPriorAccumulator() ==
+                LogPriorAccumulator{Float64}() ==
+                LogPriorAccumulator{Float64}(0.0) ==
+                zero(LogPriorAccumulator(1.0))
+            @test LogLikelihoodAccumulator(0.0) ==
+                LogLikelihoodAccumulator() ==
+                LogLikelihoodAccumulator{Float64}() ==
+                LogLikelihoodAccumulator{Float64}(0.0) ==
+                zero(LogLikelihoodAccumulator(1.0))
+            @test NumProduceAccumulator(0) ==
+                NumProduceAccumulator() ==
+                NumProduceAccumulator{Int}() ==
+                NumProduceAccumulator{Int}(0) ==
+                zero(NumProduceAccumulator(1))
+        end
+
+        @testset "addition and incrementation" begin
+            @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) ==
+                LogPriorAccumulator(2.0f0)
+            @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) ==
+                LogPriorAccumulator(2.0)
+            @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) ==
+                LogLikelihoodAccumulator(2.0f0)
+            @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) ==
+                LogLikelihoodAccumulator(2.0)
+            @test increment(NumProduceAccumulator()) ==
NumProduceAccumulator(1) + @test increment(NumProduceAccumulator{UInt8}()) == + NumProduceAccumulator{UInt8}(1) + end + + @testset "split and combine" begin + for acc in [ + LogPriorAccumulator(1.0), + LogLikelihoodAccumulator(1.0), + NumProduceAccumulator(1), + LogPriorAccumulator(1.0f0), + LogLikelihoodAccumulator(1.0f0), + NumProduceAccumulator(UInt8(1)), + ] + @test combine(acc, split(acc)) == acc + end + end + + @testset "conversions" begin + @test convert(LogPriorAccumulator{Float32}, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert( + LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0) + ) == LogLikelihoodAccumulator{Float32}(1.0f0) + @test convert(NumProduceAccumulator{UInt8}, NumProduceAccumulator(1)) == + NumProduceAccumulator{UInt8}(1) + + @test convert_eltype(Float32, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert_eltype(Float32, LogLikelihoodAccumulator(1.0)) == + LogLikelihoodAccumulator{Float32}(1.0f0) + end + + @testset "accumulate_assume" begin + val = 2.0 + logjac = pi + vn = @varname(x) + dist = Normal() + @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == + LogPriorAccumulator(1.0 + logjac + logpdf(dist, val)) + @test accumulate_assume!!( + LogLikelihoodAccumulator(1.0), val, logjac, vn, dist + ) == LogLikelihoodAccumulator(1.0) + @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) == + NumProduceAccumulator(1) + end + + @testset "accumulate_observe" begin + right = Normal() + left = 2.0 + vn = @varname(x) + @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == + LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == + LogLikelihoodAccumulator(1.0 + logpdf(right, left)) + @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) == + NumProduceAccumulator(2) + end + end + + @testset "accumulator tuples" begin + # Some accumulators 
we'll use for testing + lp_f64 = LogPriorAccumulator(1.0) + lp_f32 = LogPriorAccumulator(1.0f0) + ll_f64 = LogLikelihoodAccumulator(1.0) + ll_f32 = LogLikelihoodAccumulator(1.0f0) + np_i64 = NumProduceAccumulator(1) + + @testset "constructors" begin + @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) + # Names in NamedTuple arguments are ignored + @test AccumulatorTuple((; a=lp_f64)) == AccumulatorTuple(lp_f64) + + # Can't have two accumulators of the same type. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f64) + # Not even if their element types differ. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f32) + end + + @testset "basic operations" begin + at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64) + + @test at_all64[:LogPrior] == lp_f64 + @test at_all64[:LogLikelihood] == ll_f64 + @test at_all64[:NumProduce] == np_i64 + + @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce)) + @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 + @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce) + @test collect(at_all64) == [lp_f64, ll_f64, np_i64] + + # Replace the existing LogPriorAccumulator + @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 + # Check that setacc!! didn't modify the original + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) + # Add a new accumulator type. + @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == + AccumulatorTuple(lp_f64, ll_f64) + + @test getacc(at_all64, Val(:LogPrior)) == lp_f64 + end + + @testset "map_accumulator(s)!!" begin + # map over all accumulators + accs = AccumulatorTuple(lp_f32, ll_f32) + @test map(zero, accs) == AccumulatorTuple( + LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0) + ) + # Test that the original wasn't modified. + @test accs == AccumulatorTuple(lp_f32, ll_f32) + + # A map with a closure that changes the types of the accumulators. 
+ @test map(acc -> convert_eltype(Float64, acc), accs) == + AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0)) + + # only apply to a particular accumulator + @test map_accumulator(zero, accs, Val(:LogLikelihood)) == + AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0)) + @test map_accumulator( + acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) + ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) + end + end +end + +end diff --git a/test/ad.jl b/test/ad.jl index c34624f5b..48dffeadb 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,5 @@ using DynamicPPL: LogDensityFunction +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -31,9 +32,10 @@ using DynamicPPL: LogDensityFunction linked_varinfo = DynamicPPL.link(varinfo, m) f = LogDensityFunction(m, linked_varinfo) x = DynamicPPL.getparams(f) + # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual @testset "$adtype" for adtype in test_adtypes @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" @@ -50,24 +52,24 @@ using DynamicPPL: LogDensityFunction if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws 
Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, linked_varinfo; adtype=adtype ) else - @test DynamicPPL.TestUtils.AD.run_ad( + @test run_ad( m, adtype; varinfo=linked_varinfo, - expected_value_and_grad=(ref_logp, ref_grad), + test=WithExpectedResult(ref_logp, ref_grad), ) isa Any end end @@ -110,9 +112,8 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = LogDensityFunction( - model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true)) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end diff --git a/test/compiler.jl b/test/compiler.jl index 58e8c3efc..97121715a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -185,36 +185,33 @@ module Issue537 end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end model = testmodel_missing3([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - @test model_ === model - @test context_ isa SamplingContext - @test rng_ isa Random.AbstractRNG + # During the model evaluation, its context is wrapped in a + # SamplingContext, so `model_` is not going to be equal to `model`. + # We can still check equality of `f` though. 
+ @test model_.f === model.f + @test model_.context isa SamplingContext + @test model_.context.rng isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end false lpold = lp model = testmodel_missing4([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp == lpold + @test getlogjoint(varinfo) == lp == lpold # test DPPL#61 @model function testmodel_missing5(z) @@ -333,14 +330,14 @@ module Issue537 end function makemodel(p) @model function testmodel(x) x[1] ~ Bernoulli(p) - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end return testmodel end model = makemodel(0.5)([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @@ -364,9 +361,9 @@ module Issue537 end # TODO(torfjelde): We need conditioning for `Dict`. @test_broken f2_c() == 1 @test_broken f3_c() == 1 - @test_broken getlogp(VarInfo(f1_c)) == - getlogp(VarInfo(f2_c)) == - getlogp(VarInfo(f3_c)) + @test_broken getlogjoint(VarInfo(f1_c)) == + getlogjoint(VarInfo(f2_c)) == + getlogjoint(VarInfo(f3_c)) end @testset "custom tilde" begin @model demo() = begin @@ -601,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) + retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. 
@model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -623,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() @@ -732,10 +729,10 @@ module Issue537 end y := 100 + x return (; x, y) end - @model function demo_tracked_submodel() + @model function demo_tracked_submodel_no_prefix() return vals ~ to_submodel(demo_tracked(), false) end - for model in [demo_tracked(), demo_tracked_submodel()] + for model in [demo_tracked(), demo_tracked_submodel_no_prefix()] # Make sure it's runnable and `y` is present in the return-value. @test model() isa NamedTuple{(:x, :y)} @@ -756,6 +753,33 @@ module Issue537 end @test haskey(values, @varname(x)) @test !haskey(values, @varname(y)) end + + @model function demo_tracked_return_x() + x ~ Normal() + y := 100 + x + return x + end + @model function demo_tracked_submodel_prefix() + return a ~ to_submodel(demo_tracked_return_x()) + end + @model function demo_tracked_subsubmodel_prefix() + return b ~ to_submodel(demo_tracked_submodel_prefix()) + end + # As above, but the variables should now have their names prefixed with `b.a`. 
+        model = demo_tracked_subsubmodel_prefix()
+        varinfo = VarInfo(model)
+        @test haskey(varinfo, @varname(b.a.x))
+        @test length(keys(varinfo)) == 1
+
+        values = values_as_in_model(model, true, deepcopy(varinfo))
+        @test haskey(values, @varname(b.a.x))
+        @test haskey(values, @varname(b.a.y))
+
+        # And if include_colon_eq is set to `false`, then `values` should
+        # only contain `x`.
+        values = values_as_in_model(model, false, deepcopy(varinfo))
+        @test haskey(values, @varname(b.a.x))
+        @test length(keys(values)) == 1
     end

     @testset "signature parsing + TypeWrap" begin
diff --git a/test/context_implementations.jl b/test/context_implementations.jl
index 0ec88c07c..e16b2dc96 100644
--- a/test/context_implementations.jl
+++ b/test/context_implementations.jl
@@ -5,12 +5,12 @@
         μ ~ MvNormal(zeros(2), 4 * I)
         z = Vector{Int}(undef, length(x))
         z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x))))
-        for i in 1:length(x)
+        for i in eachindex(x)
             x[i] ~ Normal(μ[z[i]], 0.1)
         end
     end
-    test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext())
+    test([1, 1, -1])(VarInfo())
 end

 @testset "dot tilde with varying sizes" begin
     @model function test(x, size)
         y = Array{Float64,length(size)}(undef, size...)
y .~ Normal(x) - return y, getlogp(__varinfo__) + return y end for ysize in ((2,), (2, 3), (2, 3, 4)) x = randn() model = test(x, ysize) - y, lp = model() + y = model() + lp = logjoint(model, (; y=y)) @test lp ≈ sum(logpdf.(Normal.(x), y)) ys = [first(model()) for _ in 1:10_000] diff --git a/test/contexts.jl b/test/contexts.jl index 1ba099a37..597ab736c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -9,7 +9,6 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLogdensityContext, contextual_isassumption, FixedContext, ConditionContext, @@ -47,18 +46,11 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "contexts.jl" begin - child_contexts = Dict( + contexts = Dict( :default => DefaultContext(), - :prior => PriorContext(), - :likelihood => LikelihoodContext(), - ) - - parent_contexts = Dict( :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), - :minibatch => MiniBatchContext(DefaultContext(), 0.0), :prefix => PrefixContext(@varname(x)), - :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) @@ -70,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :condition4 => ConditionContext((x=[1.0, missing],)), ) - contexts = merge(child_contexts, parent_contexts) - @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) @@ -164,7 +154,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, 
ctx3) + ctx4 = DynamicPPL.SamplingContext(ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -194,9 +184,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) + sampling_model = contextualize(model, context) # Sample with the context. varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(model, varinfo, context) + DynamicPPL.evaluate!!(sampling_model, varinfo) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) @@ -235,7 +226,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Values from outer context should override inner one ctx1 = ConditionContext(n1, ConditionContext(n2)) @test ctx1.values == (x=1, y=2) - # Check that the two ConditionContexts are collapsed + # Check that the two ConditionContexts are collapsed @test childcontext(ctx1) isa DefaultContext # Then test the nesting the other way round ctx2 = ConditionContext(n2, ConditionContext(n1)) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d2269e089..8279ac51a 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,7 +1,7 @@ @testset "check_model" begin @testset "context interface" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext(model) + context = DynamicPPL.DebugUtils.DebugContext() DynamicPPL.TestUtils.test_context(context, model) end end @@ -35,9 +35,7 @@ buggy_model = buggy_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -81,9 +79,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs 
(:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -98,9 +94,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -115,9 +109,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end diff --git a/test/deprecated.jl b/test/deprecated.jl deleted file mode 100644 index 500d3eb7f..000000000 --- a/test/deprecated.jl +++ /dev/null @@ -1,57 +0,0 @@ -@testset "deprecated" begin - @testset "@submodel" begin - @testset "is deprecated" begin - @model inner() = x ~ Normal() - @model outer() = @submodel x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer()() - ) - - @model outer_with_prefix() = @submodel prefix = "sub" x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... 
model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer_with_prefix()() - ) - end - - @testset "prefixing still works correctly" begin - @model inner() = x ~ Normal() - @model function outer() - a = @submodel inner() - b = @submodel prefix = "sub" inner() - return a, b - end - @test outer()() isa Tuple{Float64,Float64} - vi = VarInfo(outer()) - @test @varname(x) in keys(vi) - @test @varname(sub.x) in keys(vi) - end - - @testset "logp is still accumulated properly" begin - @model inner_assume() = x ~ Normal() - @model inner_observe(x, y) = y ~ Normal(x) - @model function outer(b) - a = @submodel inner_assume() - @submodel inner_observe(a, b) - end - y_val = 1.0 - model = outer(y_val) - @test model() == y_val - - x_val = 1.5 - vi = VarInfo(outer(y_val)) - DynamicPPL.setindex!!(vi, x_val, @varname(x)) - @test logprior(model, vi) ≈ logpdf(Normal(), x_val) - @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) - @test logjoint(model, vi) ≈ - logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val) - end - end -end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 73a0510e9..44db66296 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,17 +14,16 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) - context = DefaultContext() @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype.tag isa 
ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 86329a51d..6737cf056 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -62,6 +62,7 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation and sampling @@ -71,7 +72,7 @@ JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, DynamicPPL.SamplingContext() + sampling_model, varinfo ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. @@ -85,7 +86,7 @@ ) JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi, DynamicPPL.SamplingContext() + sampling_model, typed_vi ) JET.test_call(f_sample, argtypes_sample) end diff --git a/test/independence.jl b/test/independence.jl deleted file mode 100644 index a4a834a61..000000000 --- a/test/independence.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "Turing independence" begin - @model coinflip(y) = begin - p ~ Beta(1, 1) - N = length(y) - for i in 1:N - y[i] ~ Bernoulli(p) - end - end - model = coinflip([1, 1, 0]) - model(SampleFromPrior(), LikelihoodContext()) -end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index 62b7ace4d..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -2,12 +2,14 @@ using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL.TestUtils.AD: run_ad using ADTypes: AutoEnzyme using Test: @test, @testset -import Enzyme: set_runtime_activity, Forward, Reverse +import Enzyme: set_runtime_activity, Forward, 
Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/linking.jl b/test/linking.jl index d424a9c2d..b0c2dcb5c 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -78,14 +78,14 @@ end vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),)) @testset "$(short_varinfo_name(vi))" for vi in vis # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) vi_linked = if mutable DynamicPPL.link!!(deepcopy(vi), model) else DynamicPPL.link(vi, model) end # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) + @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -98,7 +98,7 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) - @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) + @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) end end @@ -130,7 +130,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint(vi_linked) ≈ lp) # Invlinked. 
vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +138,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint(vi_invlinked) ≈ lp end end end @@ -164,7 +164,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +172,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint(vi_invlinked) ≈ lp end end end diff --git a/test/model.jl b/test/model.jl index 829ddd302..daa3cc743 100644 --- a/test/model.jl +++ b/test/model.jl @@ -41,7 +41,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() m = vi[@varname(m)] # extract log pdf of variable object - lp = getlogp(vi) + lp = getlogjoint(vi) # log prior probability lprior = logprior(model, vi) @@ -162,12 +162,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) vals = vi[:] Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) @test vi[:] == vals end end @@ -223,7 +223,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) + evaluate_retval = DynamicPPL.evaluate!!(model, vi) @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo # Should not return `AbstractVarInfo` when we call the model. 
@@ -332,11 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last( - DynamicPPL.evaluate!!( - model, SimpleVarInfo(OrderedDict()), SamplingContext() - ), - ) + vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -397,7 +393,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] - context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -407,13 +402,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo)) true end varinfo_linked = DynamicPPL.link(varinfo, model) @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) true end end @@ -492,9 +487,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) + DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) - @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) end end @@ -596,7 +591,10 @@ 
const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000] + chain = [ + last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for + _ in 1:10000 + ] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 61c842638..cfb222b66 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,6 +1,4 @@ @testset "logdensities_likelihoods.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) - mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -37,11 +35,6 @@ lps = pointwise_logdensities(model, vi) logp = sum(sum, values(lps)) @test logp ≈ (logprior_true + loglikelihood_true) - - # Test that modifications of Setup are picked up - lps = pointwise_logdensities(model, vi, mod_ctx2) - logp = sum(sum, values(lps)) - @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4 end end diff --git a/test/runtests.jl b/test/runtests.jl index 997a41641..c60c06786 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,13 +55,13 @@ include("test_util.jl") include("Aqua.jl") end include("utils.jl") + include("accumulators.jl") include("compiler.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") include("sampler.jl") - include("independence.jl") include("distribution_wrappers.jl") include("logdensityfunction.jl") include("linking.jl") @@ -72,7 +72,6 @@ include("test_util.jl") include("context_implementations.jl") include("threadsafe.jl") include("debug_utils.jl") - 
include("deprecated.jl") include("submodels.jl") include("bijector.jl") end diff --git a/test/sampler.jl b/test/sampler.jl index 8c4f1ed96..fe9fd331a 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,7 +84,7 @@ let inits = (; p=0.2) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -98,7 +98,7 @@ ) for c in chains @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end @@ -113,7 +113,7 @@ chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -128,7 +128,7 @@ for c in chains @test c[1].metadata.s.vals == [4] @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 380c24e7d..e300c651e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -2,12 +2,12 @@ @testset "constructor & indexing" begin @testset "NamedTuple" begin svi = SimpleVarInfo(; m=1.0) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(; m=[1.0]) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -21,20 +21,21 @@ @test !haskey(svi, @varname(m.a.b)) svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogp(svi) isa Float32 + @test getlogjoint(svi) isa Float32 - svi = SimpleVarInfo((m=1.0,), 1.0) - @test getlogp(svi) == 1.0 + svi = SimpleVarInfo((m=1.0,)) + svi = accloglikelihood!!(svi, 1.0) + @test getlogjoint(svi) == 1.0 end 
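The hunk above replaces the removed two-argument `SimpleVarInfo((m=1.0,), 1.0)` constructor with an explicit accumulator call. A minimal sketch of the split log-density accessors introduced in 0.37 (a hedged illustration, assuming `DynamicPPL` and `Distributions` are loaded; `demo` is a hypothetical model, not one from this test suite):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()          # assume statement: accumulates into the log prior
    0.0 ~ Normal(x, 1)    # literal observe statement: accumulates into the log likelihood
    return nothing
end

vi = VarInfo(demo())
lp_prior = getlogprior(vi)        # log prior only
lp_lik = getloglikelihood(vi)     # log likelihood only
lp_joint = getlogjoint(vi)        # their sum: what the old scalar `getlogp` returned
getlogp(vi)                       # now a NamedTuple with keys `logprior` and `loglikelihood`
```

This is why the patch mechanically rewrites `getlogp(vi)` to `getlogjoint(vi)` wherever a scalar log joint is expected.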
@testset "Dict" begin svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -59,12 +60,12 @@ @testset "VarNamedVector" begin svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -97,12 +98,11 @@ for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - lp_orig = getlogp(vi) + vi = last(DynamicPPL.evaluate!!(model, vi)) # `link!!` vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogp(vi_linked) + lp_linked = getlogjoint(vi_linked) values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_constrained... ) @@ -113,7 +113,7 @@ # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogp(vi_invlinked) + lp_invlinked = getlogjoint(vi_invlinked) lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( model, values_constrained... ) @@ -152,13 +152,13 @@ # DynamicPPL.settrans!!(deepcopy(svi_dict), true), # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) - # RandOM seed is set in each `@testset`, so we need to sample + # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. 
retval = model() ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -166,7 +166,7 @@ end # Logjoint should be non-zero wp. 1. - @test getlogp(svi_new) != 0 + @test getlogjoint(svi_new) != 0 ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @@ -201,7 +201,7 @@ svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end - # Reset the logp field. + # Reset the logp accumulators. svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. @@ -226,9 +226,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -236,7 +236,7 @@ for vn in keys(svi) svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + retval, svi = DynamicPPL.evaluate!!(model, svi) @test retval.m == svi[@varname(m)] # `m` is unconstrained @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` @@ -250,7 +250,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) + lp = getlogjoint(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -273,7 +273,7 @@ ) # Resulting varinfo should no longer be transformed. 
- vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) + vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. @@ -281,9 +281,7 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!( - model, deepcopy(vi_linked), DefaultContext() - ) + retval, vi_linked_result = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original @@ -306,7 +304,7 @@ DynamicPPL.tovec(retval_unconstrained.m) # The resulting varinfo should hold the correct logp. - lp = getlogp(vi_linked_result) + lp = getlogjoint(vi_linked_result) @test lp ≈ lp_true end end diff --git a/test/submodels.jl b/test/submodels.jl index e79eed2c3..d3a2f17e7 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -35,7 +35,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) end @@ -67,7 +67,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(y)]) end @@ -99,7 +99,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) # Check the keys @test Set(keys(VarInfo(model))) == 
Set([@varname(b.y)]) end @@ -148,7 +148,7 @@ using Test # No conditioning vi = VarInfo(h()) @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) - @test getlogp(vi) == + @test getlogjoint(vi) == logpdf(Normal(), vi[@varname(a.b.x)]) + logpdf(Normal(), vi[@varname(a.b.y)]) @@ -174,7 +174,7 @@ using Test @testset "$name" for (name, model) in models vi = VarInfo(model) @test Set(keys(vi)) == Set([@varname(a.b.y)]) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) end end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index ededf78b0..24a738a78 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -4,9 +4,12 @@ threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} - @test length(threadsafe_vi.logps) == Threads.nthreads() * 2 - @test all(iszero(x[]) for x in threadsafe_vi.logps) + @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} + @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... 
+ ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end # TODO: Add more tests of the public API @@ -14,23 +17,27 @@ vi = VarInfo(gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - lp = getlogp(vi) - @test getlogp(threadsafe_vi) == lp + lp = getlogjoint(vi) + @test getlogjoint(threadsafe_vi) == lp - acclogp!!(threadsafe_vi, 42) - @test threadsafe_vi.logps[Threads.threadid()][] == 42 - @test getlogp(vi) == lp - @test getlogp(threadsafe_vi) == lp + 42 + threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) + @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 + @test getlogjoint(vi) == lp + @test getlogjoint(threadsafe_vi) == lp + 42 - resetlogp!!(threadsafe_vi) - @test iszero(getlogp(vi)) - @test iszero(getlogp(threadsafe_vi)) - @test all(iszero(x[]) for x in threadsafe_vi.logps) + threadsafe_vi = resetlogp!!(threadsafe_vi) + @test iszero(getlogjoint(threadsafe_vi)) + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... + ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - setlogp!!(threadsafe_vi, 42) - @test getlogp(vi) == 42 - @test getlogp(threadsafe_vi) == 42 - @test all(iszero(x[]) for x in threadsafe_vi.logps) + threadsafe_vi = setlogprior!!(threadsafe_vi, 42) + @test getlogjoint(threadsafe_vi) == 42 + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... 
+ ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @testset "model" begin @@ -45,10 +52,11 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wthreads(x) vi = VarInfo() - wthreads(x)(vi) - lp_w_threads = getlogp(vi) + model(vi) + lp_w_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -57,23 +65,19 @@ println("With `@threads`:") println(" default:") - @time wthreads(x)(vi) + @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - @test getlogp(vi) ≈ lp_w_threads + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @test getlogjoint(vi) ≈ lp_w_threads + # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes + @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -82,10 +86,11 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wothreads(x) vi = VarInfo() - wothreads(x)(vi) - lp_wo_threads = getlogp(vi) + model(vi) + lp_wo_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -94,24 +99,18 @@ println("Without `@threads`:") println(" default:") - @time wothreads(x)(vi) + @time model(vi) @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. 
- DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - @test getlogp(vi) ≈ lp_w_threads + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo + @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) end end diff --git a/test/utils.jl b/test/utils.jl index 7a7338fa7..081e58d61 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,15 +1,34 @@ @testset "utils.jl" begin @testset "addlogprob!" begin @model function testmodel() - global lp_before = getlogp(__varinfo__) + global lp_before = getlogjoint(__varinfo__) @addlogprob!(42) - return global lp_after = getlogp(__varinfo__) + return global lp_after = getlogjoint(__varinfo__) end - model = testmodel() - varinfo = VarInfo(model) + varinfo = VarInfo(testmodel()) @test iszero(lp_before) - @test getlogp(varinfo) == lp_after == 42 + @test getlogjoint(varinfo) == lp_after == 42 + @test getloglikelihood(varinfo) == 42 + + @model function testmodel_nt() + global lp_before = getlogjoint(__varinfo__) + @addlogprob! (; logprior=(pi + 1), loglikelihood=42) + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_nt()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + 1 + pi + @test getloglikelihood(varinfo) == 42 + @test getlogprior(varinfo) == pi + 1 + + @model function testmodel_nt2() + global lp_before = getlogjoint(__varinfo__) + llh_nt = (; loglikelihood=42) + @addlogprob! 
llh_nt + return global lp_after = getlogjoint(__varinfo__) + end end @testset "getargs_dottilde" begin diff --git a/test/varinfo.jl b/test/varinfo.jl index 444a88875..75868eb66 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -72,7 +72,7 @@ end function test_base(vi_original) vi = deepcopy(vi_original) - @test getlogp(vi) == 0 + @test getlogjoint(vi) == 0 @test isempty(vi[:]) vn = @varname x @@ -116,13 +116,25 @@ end @testset "get/set/acc/resetlogp" begin function test_varinfo_logp!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 - vi = DynamicPPL.setlogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 1.0 - vi = DynamicPPL.acclogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 + vi = DynamicPPL.setlogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 1.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 1.0 + vi = DynamicPPL.acclogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 2.0 + vi = DynamicPPL.setloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 1.0 + @test DynamicPPL.getlogjoint(vi) === 3.0 + vi = DynamicPPL.accloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 4.0 vi = DynamicPPL.resetlogp!!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 end vi = VarInfo() @@ -133,6 +145,104 @@ end test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + @testset "accumulators" begin + @model function demo() + a ~ Normal() + b ~ Normal() + c ~ Normal() + d ~ Normal() + return nothing + end + + values = (; a=1.0, b=2.0, c=3.0, d=4.0) + lp_a = logpdf(Normal(), values.a) + lp_b = logpdf(Normal(), values.b) + lp_c = logpdf(Normal(), values.c) + lp_d = logpdf(Normal(), values.d) + m 
= demo() | (; c=values.c, d=values.d) + + vi = DynamicPPL.reset_num_produce!!( + DynamicPPL.unflatten(VarInfo(m), collect(values)) + ) + + vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) + @test getlogprior(vi) == lp_a + lp_b + @test getloglikelihood(vi) == lp_c + lp_d + @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d) + @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d + @test get_num_produce(vi) == 2 + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = accloglikelihood!!(vi, 1.0) + getloglikelihood(vi) == lp_c + lp_d + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + @test begin + vi = setloglikelihood!!(vi, -1.0) + getloglikelihood(vi) == -1.0 + end + @test begin + vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0) + end + @test begin + vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) + getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0) + end + @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) + + vi = last( + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),)) + ), + ) + @test getlogprior(vi) == lp_a + lp_b + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogLikelihood" getlogp(vi) + @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) + @test_throws r"has no field `?NumProduce" get_num_produce(vi) + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + + vi = last( + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),)) + ), + ) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # 
missing field is surrounded by backticks) + @test_throws r"has no field `?LogPrior" getlogprior(vi) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogPrior" getlogp(vi) + @test_throws r"has no field `?LogPrior" getlogjoint(vi) + @test get_num_produce(vi) == 2 + + # Test evaluating without any accumulators. + vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogPrior" getlogprior(vi) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogPrior" getlogp(vi) + @test_throws r"has no field `?LogPrior" getlogjoint(vi) + @test_throws r"has no field `?NumProduce" get_num_produce(vi) + @test_throws r"has no field `?NumProduce" reset_num_produce!!(vi) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -376,10 +486,17 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model does not perform linking + # Check that instantiating the model using SampleFromUniform does not + # perform linking + # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) + # specifically in this test is because SFU samples from the linked + # distribution i.e. in unconstrained space. However, it does this not + # by linking the varinfo but by transforming the distributions on the + # fly. That's why it's worth specifically checking that it can do this + # without having to change the VarInfo object. 
vi = VarInfo() meta = vi.metadata - model(vi, SampleFromUniform()) + _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -448,46 +565,58 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) + + ## `untyped_varinfo` + vi = DynamicPPL.untyped_varinfo(model) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) # Sample in unconstrained space. 
- vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @testset "values_as" begin @@ -566,7 +695,7 @@ end end # Evaluate the model once to update the logp of the varinfo. 
- varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) + varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) varinfo_linked = if mutating DynamicPPL.link!!(deepcopy(varinfo), model) @@ -589,8 +718,8 @@ end lp = logjoint(model, varinfo) @test lp ≈ lp_true - @test getlogp(varinfo) ≈ lp_true - lp_linked = getlogp(varinfo_linked) + @test getlogjoint(varinfo) ≈ lp_true + lp_linked = getlogjoint(varinfo_linked) @test lp_linked ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for @@ -602,13 +731,36 @@ end varinfo_linked_unflattened, model ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) - @test getlogp(varinfo_invlinked) ≈ lp_true + @test getlogjoint(varinfo_invlinked) ≈ lp_true end end end end end + @testset "unflatten type stability" begin + @model function demo(y) + x ~ Normal() + y ~ Normal(x, 1) + return nothing + end + + model = demo(0.0) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, (; x=1.0), (@varname(x),); include_threadsafe=true + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + # Skip the severely inconcrete `SimpleVarInfo` types, since checking for type + # stability for them doesn't make much sense anyway. + if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} || + varinfo isa + DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}} + continue + end + @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) + end + end + @testset "subset" begin @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) @@ -846,9 +998,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. 
model2 = demo(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -867,9 +1017,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -934,19 +1082,19 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -954,12 +1102,12 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, 
vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @@ -968,21 +1116,21 @@ end vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -990,12 +1138,12 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] @@ -1010,8 +1158,8 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) # `Int`. 
- @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) end end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index bd3f5553f..57a8175d4 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -603,20 +603,18 @@ end DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) # Is evaluation correct? - varinfo_eval = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) - ) + varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo))) # Log density should be the same. - @test getlogp(varinfo_eval) ≈ logp_true + @test getlogjoint(varinfo_eval) ≈ logp_true # Values should be the same. DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? varinfo_sample = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) + DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) ) # Log density should be different. - @test getlogp(varinfo_sample) != getlogp(varinfo) + @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. DynamicPPL.TestUtils.test_values( varinfo_sample, value_true, vns; compare=!isequal
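Most hunks in this patch apply the same two mechanical migrations: `evaluate!!(model, vi, SamplingContext())` becomes `evaluate_and_sample!!(model, vi)`, and the now-redundant `DefaultContext()` argument to `evaluate!!` is dropped. A hedged before/after sketch of the pattern (assuming DynamicPPL 0.37; `gauss` is a hypothetical one-variable model used only for illustration):

```julia
using DynamicPPL, Distributions

@model function gauss()
    return x ~ Normal()
end
model = gauss()

# ≤ 0.36 (old API, for comparison):
#   _, vi = DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())
#   _, vi = DynamicPPL.evaluate!!(model, vi, DefaultContext())
#   lp = getlogp(vi)    # scalar log joint

# 0.37: the context argument is gone, and sampling has its own entry point.
_, vi = DynamicPPL.evaluate_and_sample!!(model, VarInfo())  # draws fresh values
_, vi = DynamicPPL.evaluate!!(model, vi)                    # re-evaluates at stored values
lp = getlogjoint(vi)    # `getlogp` now returns a NamedTuple instead
```

An optional sampler can still be passed, as in the `SampleFromUniform` hunk above: `DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform())`.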