Multivariate linear regression with gaussian mixture pool #446
Answered
by
bvdmitri
max-de-rooij
asked this question in
Q&A
Answered by bvdmitri on Mar 22, 2025
Replies: 2 comments 6 replies
My guess is that you also need to initialize messages on `α` and `β`:

```julia
init_mixture = @initialization begin
    μ(as) = [MvNormalMeanCovariance(zeros(2), 10.0*diageye(2)) for i in 1:priors.n_classes]
    q(bs) = [MvNormalMeanCovariance(zeros(2), 10.0*diageye(2)) for i in 1:priors.n_classes]
    q(ws) = [Wishart(2 + 2, 1e2*diageye(2)) for i in 1:priors.n_classes]
    q(σ) = vague(Gamma)
    q(s) = vague(Dirichlet, priors.n_classes)
    q(z) = vague(RxInfer.Categorical, priors.n_classes)
    μ(α) = ...
    μ(β) = ...
end
```

If that doesn't help, then perhaps also initialize posteriors on `α` and `β`:

```julia
init_mixture = @initialization begin
    μ(as) = [MvNormalMeanCovariance(zeros(2), 10.0*diageye(2)) for i in 1:priors.n_classes]
    q(bs) = [MvNormalMeanCovariance(zeros(2), 10.0*diageye(2)) for i in 1:priors.n_classes]
    q(ws) = [Wishart(2 + 2, 1e2*diageye(2)) for i in 1:priors.n_classes]
    q(σ) = vague(Gamma)
    q(s) = vague(Dirichlet, priors.n_classes)
    q(z) = vague(RxInfer.Categorical, priors.n_classes)
    μ(α) = ...
    μ(β) = ...
    q(α) = ...
    q(β) = ...
end
```
Ah, another thing: this piece of the model most likely will not work:

```julia
covariance = diageye(dim) * σ
```

We don't have specialized message passing rules for this case, but we do have a specialized factor node for it. Replace

```julia
data[i] ~ MvNormal(mean = estimation[i], covariance = diageye(dim) * σ)
```

with

```julia
data[i] ~ MvNormalMeanScalePrecision(estimation[i], σ) # equivalent to the line above and uses conjugate `Gamma` prior too
```
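One detail worth spelling out when making this swap is how the two parameterizations relate. `MvNormalMeanScalePrecision` takes a scalar that scales the identity precision matrix. Writing that scalar as $\gamma$ and the dimension `dim` as $d$ (both symbols are introduced here for illustration, they are not in the original post), the relationship is roughly:

```latex
\Lambda = \gamma I_d,
\qquad
\Sigma = \Lambda^{-1} = \gamma^{-1} I_d
```

The original line sets $\Sigma = \sigma I_d$ instead. Both lines describe the same family of isotropic Gaussians, but the meaning of `σ` changes: after the replacement it acts as a precision (the reciprocal of the former variance scale), which is the quantity a `Gamma` prior is conjugate to. Any prior hyperparameters chosen with a variance in mind may need to be revisited accordingly.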
Hey, I looked into it, and it does seem strange. However, I think I identified the main issue.
In the example you shared, the code didn’t run at all because the `individuals` array contained indices larger than the length of `α`, causing an out-of-bounds error when accessing `α[individuals[i]]`. I fixed this by mapping `individuals` to a valid range:
I then used `fixed_individuals` instead of `individuals` in the …
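For illustration, one way such a remapping could look is sketched below. This is a minimal sketch, not necessarily the mapping used in this reply: the sample data and the names `labels` and `index_of` are hypothetical, and only `individuals` and `fixed_individuals` come from the discussion.

```julia
# Hypothetical raw labels: arbitrary integers, possibly larger than length(α)
individuals = [3, 7, 7, 12, 3]

# Distinct labels in order of first appearance
labels = unique(individuals)

# Map each raw label to a contiguous index in 1:length(labels)
index_of = Dict(label => i for (i, label) in enumerate(labels))

# Remapped indices, safe to use as α[fixed_individuals[i]]
# as long as length(α) == length(labels)
fixed_individuals = [index_of[label] for label in individuals]
```

With this kind of remapping, the number of per-individual parameters should match `length(labels)` so that every index stays in bounds.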