-
Notifications
You must be signed in to change notification settings - Fork 426
Reduce allocations of cov(::MultivariateMixture)
#1967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #1967 +/- ##
==========================================
+ Coverage 86.26% 86.27% +0.01%
==========================================
Files 146 146
Lines 8763 8778 +15
==========================================
+ Hits 7559 7573 +14
- Misses 1204 1205 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
cov(::MultivariateMixture)
Would you be able to quickly review this small PR @jishnub? |
Looks fine to me. I suppose |
src/mixtures/mixturemodel.jl
Outdated
md = mean(c) - m | ||
axpy!(pi, md*md', V) | ||
md .= mean(c) .- m | ||
BLAS.syr!('U', pi, md, V) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the pi
be cast to the appropriate type here? I think BLAS
expects floating-point numbers. But then, if tests pass, probably this is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In principle, I'd like to avoid BLAS calls even though md
and V
are (currently) Array
s. But it seems mul!(::Symmetric, ::Vector, ::Adjoint, ...)
is not supported and according to a benchmark on my machine repeated generic mul!
calls can be slower than syr!
and a single copytri!
... Everything apart from pi
is hardcoded to Float64
currently though (something to be fixed as well), so
BLAS.syr!('U', pi, md, V) | |
BLAS.syr!('U', Float64(pi), md, V) |
should make it safe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I agree that it is a bit unfortunate that there isn't a higher level API for either of these. This should be fine for now with the change above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a test to cover this line? The fact that the tests passed without this change seems to suggest that this isn't covered.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll check. I assume that - as in many tests - only distributions with Float64 parameters are tested. Here, eg, tests with a multivariate mixture with mixture probabilities of type Float64 would have passed even without the change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was indeed the case, it was only tested with pi::Float64
. I added a test with Float32
in a065a65 that would error without the Float64(pi)
change.
Indeed, I used it only since Distributions already uses it in a few other places. |
The PR reduces allocations of
cov(::MultivatiateMixture)
(addresses a TODO item).