Move `hasvalue` and `getvalue` from DynamicPPL; implement extra Distributions-based methods #125
base: main
Conversation
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff             @@
##             main     #125      +/-   ##
==========================================
+ Coverage   83.56%   86.28%   +2.72%
==========================================
  Files           2        5       +3
  Lines         292      401     +109
==========================================
+ Hits          244      346     +102
- Misses         48       55       +7
```
```julia
"""
Remove identity lenses from composed optics.
"""
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
    return _strip_identity(o.outer)
end
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
    return _strip_identity(o.inner)
end
_strip_identity(o::Base.ComposedFunction) = o
_strip_identity(o::Accessors.PropertyLens) = o
_strip_identity(o::Accessors.IndexLens) = o
_strip_identity(o::typeof(identity)) = o
```
`normalise` strips identities now, so this function isn't needed any more.
```julia
_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
_head(o::Accessors.PropertyLens) = o
_head(o::Accessors.IndexLens) = o
_head(::typeof(identity)) = identity
```
`_head`, `_tail`, `_init`, and `_last` take their names from the equivalent Haskell functions on linked lists:

```haskell
λ> head [1,2,3]
1
λ> tail [1,2,3]
[2,3]
λ> init [1,2,3]
[1,2]
λ> last [1,2,3]
3

-- empty list is turned into identity in our case
λ> head [1]
1
λ> tail [1]
[]
λ> init [1]
[]
λ> last [1]
1
```
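The same naming carries over to composed optics, because `∘` left-associates: the innermost component of a chain is the first function applied, i.e. the Haskell-style head. A toy illustration with plain functions instead of lenses (just `Base.ComposedFunction`, no Accessors, using hypothetical names `inc`/`dbl`):

```julia
inc(x) = x + 1
dbl(x) = 2x

# ∘ left-associates: inc ∘ dbl ∘ inc == (inc ∘ dbl) ∘ inc, so the
# innermost component is the first function applied (the "head") and
# the outer part is everything applied after it (the "tail").
o = inc ∘ dbl ∘ inc

o.inner   # inc: the head, applied first
o.outer   # inc ∘ dbl: the rest of the chain
o(3)      # inc(dbl(inc(3))) == 9
```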
Julia minimum version bump
I bumped the minimum Julia version to 1.10, as I don't want to add extra code to handle package extensions on pre-1.9 versions. The most important packages in TuringLang already require >= 1.10 anyway.
Moving functions from DynamicPPL
This PR moves `hasvalue` and `getvalue` from DynamicPPL to AbstractPPL: https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/utils.jl#L763-L954

Many of the helper functions in DynamicPPL are not actually needed, because existing functionality here accomplishes much the same thing. I have modified the implementations accordingly.
Distributions-based methods
This part is new and warrants more explanation. To begin, notice the default behaviour of `hasvalue`:

This makes sense, because `d` alone does not give us enough information to reconstruct some arbitrary variable `x`.

However, let's say that we know `x` is to be sampled from a given distribution `dist`. In this case, we do have enough information to determine whether `x` can be reconstructed. This PR therefore also implements the following methods:

The motivation for this is to (properly) fix issues where values for multivariate distributions are specified separately; see e.g. TuringLang/DynamicPPL.jl#712, and also this comment: TuringLang/DynamicPPL.jl#710 (comment).
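To make the idea concrete, here is a hedged toy sketch (hypothetical `toy_*` helpers, not the actual AbstractPPL API; strings stand in for VarNames and `n` stands in for `length(dist)`): knowing the distribution tells us how many scalar components are needed, so we can check for all of them and reassemble the full value.

```julia
# Toy sketch, not the actual API: the distribution's length n tells us
# which scalar entries "x[1]" … "x[n]" must all be present.
toy_hasvalue(vals, name, n) = all(haskey(vals, "$name[$i]") for i in 1:n)
toy_getvalue(vals, name, n) = [vals["$name[$i]"] for i in 1:n]

vals = Dict("x[1]" => 1.0, "x[2]" => 1.0)

toy_hasvalue(vals, "x", 2)   # true: both components are present
toy_getvalue(vals, "x", 2)   # reconstructs [1.0, 1.0]
toy_hasvalue(vals, "x", 3)   # false: a length-3 variable cannot be rebuilt
```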
One might argue that we should force users to specify things properly: if `x ~ MvNormal(zeros(2), I)`, then the user should condition on `Dict(@varname(x) => [1.0, 1.0])` rather than `Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0)`. In an ideal world I would do that, and even now I would still advocate for making this general guideline clear in, e.g., the docs.

However, there remains one specific case where this isn't enough, namely DynamicPPL's `predict(model, chain)` and `returned(model, chain)`. These methods require extracting variable values from `chain`, inserting them into a VarInfo, and rerunning the model with the given values. Unfortunately, `chain` is a lossy storage format, because array-valued variables like `x` are split up into `x[1]` and `x[2]`, and it's not possible to recover the original shape of `x`.

Up until this PR, this has been handled in DynamicPPL using the `setval_and_resample!` and `nested_setindex_maybe` methods, which perform some direct manipulation of VarInfos. I think these methods are slightly dangerous and can lead to subtle bugs; for example, if only part of the variable `x` is given, the entire variable `x` is marked as not-to-be-resampled: https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/varinfo.jl#L2177-L2181

The good news, though, is that when evaluating a model, we have access to the distribution that `x` is supposed to be sampled from. Thus, we can determine whether enough of the `x[i]`'s are given to reconstruct it, which is what these new methods do. So we can deal with this in a more principled fashion: if we can find all the indices needed to reconstruct the value of `x`, then we can confidently set that value; if we can't, then we don't even attempt to set any of the individual indices, because `hasvalue` will return false.
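A minimal sketch of this all-or-nothing behaviour (toy code with a hypothetical `maybe_set!` helper; strings stand in for VarNames and `n` for the distribution's length): a value is only ever written when the whole variable can be reconstructed.

```julia
# Only set `name` in `target` when every scalar component "name[i]" for
# i in 1:n can be found in `vals`; otherwise touch nothing at all,
# rather than setting partial indices.
function maybe_set!(target, vals, name, n)
    all(haskey(vals, "$name[$i]") for i in 1:n) || return false
    target[name] = [vals["$name[$i]"] for i in 1:n]
    return true
end

target = Dict{String,Any}()
maybe_set!(target, Dict("x[1]" => 1.0), "x", 2)                 # false: x[2] missing, target untouched
maybe_set!(target, Dict("x[1]" => 1.0, "x[2]" => 2.0), "x", 2)  # true: target["x"] == [1.0, 2.0]
```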
will return false.Remaining questions:
hasvalue
andgetvalue
have extremely similar logic, do we really need to have two functions with almost the same implementation? I've held off on attempting to do this because I'm worried about type stability, i.e.getvalue
is inherently type-unstable, and maybe guarding calls togetvalue
behind a call tohasvalue
avoids leaking type instability into the caller function. However, I think this is reliant on the compiler being able to infer the return value ofhasvalue
through e.g. constant propagation?!TODO
hasvalue
for other distributionsgetvalue
for distributionsThis PR doesn't support
ProductNamedTupleDistribution
. It shouldn't be overly complicated to implement IMO. However, almost nothing else in TuringLang works withProductNamedTupleDistribution
, so I don't feel bad not implementing it.Closes #124
This is required for the `InitContext` PR TuringLang/DynamicPPL.jl#967, as `ParamsInit` needs to use `hasvalue` and `getvalue`. Specifically, I also want to use `ParamsInit` to handle `predict`, hence the need for the Distributions-based methods.