
Move hasvalue and getvalue from DynamicPPL; implement extra Distributions-based methods #125


Open · wants to merge 13 commits into main
Conversation

@penelopeysm (Member) commented Jul 5, 2025

Julia minimum version bump

I bumped the minimum Julia version to 1.10, as I didn't want to add extra code to handle package extensions on pre-1.9 Julia. The most important packages in TuringLang already require >= 1.10 anyway.
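(For context: package extensions only exist on Julia 1.9 and later. The new Distributions code lives in an extension, declared in Project.toml roughly as follows; this is a sketch of the standard layout, not a copy of the actual file.)

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

[extensions]
AbstractPPLDistributionsExt = "Distributions"

On Julia versions before 1.9, supporting this pattern would require a Requires.jl-style fallback, hence the bump.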

Moving functions from DynamicPPL

This PR moves hasvalue and getvalue from DynamicPPL to AbstractPPL. https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/utils.jl#L763-L954

Many of the helper functions in DynamicPPL turn out not to be needed, because AbstractPPL already has functionality that accomplishes much the same; I modified the implementations accordingly.

Distributions-based methods

This part is new and warrants more explanation. To begin, notice the default behaviour of hasvalue:

julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0)
Dict{VarName{:x, Accessors.IndexLens{Tuple{Int64}}}, Float64} with 2 entries:
  x[1] => 1.0
  x[2] => 1.0

julia> hasvalue(d, @varname(x))
false

This makes sense, because d alone does not give us enough information to reconstruct some arbitrary variable x: for all we know, x could be a length-2 vector, a longer vector whose remaining elements are missing, or a container of a different shape entirely.

However, let's say that we know x is to be sampled from a given distribution dist. In this case, we do have enough information to determine whether x can be reconstructed. This PR therefore also implements the following methods:

julia> using Distributions, LinearAlgebra

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I))
true

julia> getvalue(d, @varname(x), MvNormal(zeros(2), I))
[1.0, 1.0]

julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I))
false

The motivation for this is to (properly) fix issues where values for multivariate distributions are specified element-by-element; see e.g. TuringLang/DynamicPPL.jl#712, and also this comment: TuringLang/DynamicPPL.jl#710 (comment).

One might argue that we should force users to specify things properly, i.e., if x ~ MvNormal(zeros(2), I) then the user should condition on Dict(@varname(x) => [1.0, 1.0]) rather than Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0). In an ideal world I would do that, and even now, I would still advocate for making this general guideline clear in e.g. the docs.
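For comparison, the properly-specified version already works with the plain two-argument methods, no distribution needed (an illustrative session; d2 is just a name for this example):

julia> d2 = Dict(@varname(x) => [1.0, 1.0]);

julia> hasvalue(d2, @varname(x))
true

julia> getvalue(d2, @varname(x))
2-element Vector{Float64}:
 1.0
 1.0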

However, there remains one specific case where this isn't enough, namely in DynamicPPL's predict(model, chain) or returned(model, chain). These methods require extracting variable values from chain, inserting them into a VarInfo, and rerunning the model with the given values. Unfortunately, chain is a lossy storage format, because array-valued variables like x are split up into x[1] and x[2] and it's not possible to recover the original shape of x.

Up until this PR, this has been handled in DynamicPPL using the setval_and_resample! and nested_setindex_maybe methods, which perform some direct manipulation of VarInfos. I think these methods are slightly dangerous and can lead to subtle bugs: for example, if only part of the variable x is given, the entire variable x is marked as not-to-be-resampled: https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/varinfo.jl#L2177-L2181

The good news, though, is that when evaluating a model, we have access to the distribution that x is supposed to be sampled from. Thus, we can determine whether enough of the x[i]'s are given to reconstruct it, which is what these new methods do. So, we can deal with this in a more principled fashion: if we can find all the indices needed to reconstruct the value of x, then we can confidently set that value; if we can't, then we don't even attempt to set any of the individual indices because hasvalue will return false.
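To make that concrete, the core of the multivariate check looks conceptually like this (a simplified sketch, not the actual code in ext/AbstractPPLDistributionsExt.jl; has_all_elements is a made-up name, and it assumes vn is a plain x with no nested indexing):

using AbstractPPL, Distributions
using Accessors: IndexLens

# Sketch: x is reconstructible from vals iff the whole vector is stored
# under vn itself, or every element x[i] for i in 1:length(dist) is present.
function has_all_elements(vals::AbstractDict, vn::VarName{sym}, dist::MultivariateDistribution) where {sym}
    hasvalue(vals, vn) && return true
    return all(hasvalue(vals, VarName{sym}(IndexLens((i,)))) for i in 1:length(dist))
end

The corresponding getvalue would walk the same indices to assemble the vector, which is also why the two functions end up sharing so much logic.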

Remaining questions:

  • I wonder if we can simplify the API. hasvalue and getvalue have extremely similar logic, so do we really need two functions with almost the same implementation? I've held off on attempting this because I'm worried about type stability: getvalue is inherently type-unstable, and guarding calls to getvalue behind a call to hasvalue may avoid leaking that instability into the caller. However, I think this relies on the compiler being able to infer the return value of hasvalue through e.g. constant propagation (see the sketch after this list).
  • Not sure if this should be a minor bump. According to semver, nothing in here is breaking, hence I did a patch bump; but the changes are quite large, and a minor bump may feel more correct.
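On the first point, one possible consolidation would be a single lookup that returns nothing on failure, with the current pair as thin wrappers. This is purely a sketch: trygetvalue is a hypothetical name, and the real subsumption-aware lookup logic is elided.

using AbstractPPL: VarName

# Hypothetical single lookup; callers that care about type stability can
# guard on `!== nothing` themselves. (Ignores values that are literally
# `nothing`, which a real implementation would need to handle.)
trygetvalue(vals::AbstractDict, vn::VarName) =
    haskey(vals, vn) ? vals[vn] : nothing  # real version would also check subsumed keys

# the existing pair would then become thin wrappers:
hasvalue(vals, vn) = trygetvalue(vals, vn) !== nothing
getvalue(vals, vn) = something(trygetvalue(vals, vn))  # throws if absent

Whether inference can propagate the hasvalue result into the subsequent getvalue branch is exactly the constant-propagation question above.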

TODO

  • hasvalue for other distributions
  • getvalue for distributions
  • Appropriate tests
  • API documentation for the distributions bits
  • Changelog

This PR doesn't support ProductNamedTupleDistribution. It shouldn't be overly complicated to implement, IMO; however, almost nothing else in TuringLang works with ProductNamedTupleDistribution, so I don't feel bad about not implementing it.

Closes #124

This is required for the InitContext PR TuringLang/DynamicPPL.jl#967, as ParamsInit needs to use hasvalue and getvalue. In particular, I also want to use ParamsInit to handle predict, hence the need for the Distributions-based methods.

@penelopeysm marked this pull request as draft on July 5, 2025 at 22:34

codecov bot commented Jul 5, 2025

Codecov Report

Attention: Patch coverage is 85.38462% with 19 lines in your changes missing coverage. Please review.

Project coverage is 86.28%. Comparing base (7be9556) to head (6a01588).

Files with missing lines (patch %, lines missing):
ext/AbstractPPLDistributionsExt.jl 76.92% 15 Missing ⚠️
src/hasvalue.jl 93.02% 3 Missing ⚠️
src/varname.jl 95.45% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #125      +/-   ##
==========================================
+ Coverage   83.56%   86.28%   +2.72%     
==========================================
  Files           2        5       +3     
  Lines         292      401     +109     
==========================================
+ Hits          244      346     +102     
- Misses         48       55       +7     


Base automatically changed from py/composed-assoc to main July 6, 2025 10:58

coveralls commented Jul 6, 2025

Pull Request Test Coverage Report for Build 16104307329

Details

  • 111 of 130 (85.38%) changed or added relevant lines in 3 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+2.7%) to 86.284%

Files with changes missing coverage (covered lines, changed/added lines, %):
src/varname.jl 21 22 95.45%
src/hasvalue.jl 40 43 93.02%
ext/AbstractPPLDistributionsExt.jl 50 65 76.92%
Totals: coverage change from base (Build 16098285788): +2.7%
Covered Lines: 346
Relevant Lines: 401


@TuringLang deleted a comment from github-actions bot on Jul 6, 2025

github-actions bot commented Jul 6, 2025

AbstractPPL.jl documentation for PR #125 is available at:
https://TuringLang.github.io/AbstractPPL.jl/previews/PR125/

@penelopeysm marked this pull request as ready for review on July 6, 2025 at 19:03
@penelopeysm requested a review from mhauru on July 7, 2025 at 10:34
Comment on lines -968 to -980
"""
Remove identity lenses from composed optics.
"""
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
return _strip_identity(o.outer)
end
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
return _strip_identity(o.inner)
end
_strip_identity(o::Base.ComposedFunction) = o
_strip_identity(o::Accessors.PropertyLens) = o
_strip_identity(o::Accessors.IndexLens) = o
_strip_identity(o::typeof(identity)) = o
@penelopeysm (Member, Author):

normalise strips identities now, so this function isn't needed any more.

Comment on lines +995 to +998
_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
_head(o::Accessors.PropertyLens) = o
_head(o::Accessors.IndexLens) = o
_head(::typeof(identity)) = identity
@penelopeysm (Member, Author):

_head, _tail, _init, and _last take their names from the equivalent Haskell functions on linked lists:

λ> head [1,2,3]
1
λ> tail [1,2,3]
[2,3]
λ> init [1,2,3]
[1,2]
λ> last [1,2,3]
3

-- empty list is turned into identity in our case
λ> head [1]
1
λ> tail [1]
[]
λ> init [1]
[]
λ> last [1]
1
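In Julia terms, assuming the composed optic has been normalised so that the first-applied optic sits innermost, the analogy looks roughly like this (_tail is shown as it would presumably be defined, mirroring _head above):

using Accessors: IndexLens, PropertyLens

optic = IndexLens((1,)) ∘ PropertyLens{:a}()  # roughly @o _.a[1]; _.a is applied first

_head(optic)     # PropertyLens{:a}() -- the first optic applied, like Haskell's head
# and by analogy: _tail(o::ComposedFunction) = o.outer, so that
# _tail(optic) == IndexLens((1,))    -- everything after the first, like Haskell's tail
_head(identity)  # identity stands in for the empty list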
