Skip to content

Commit 904f1f8

Browse files
committed
Revise _init_subsampling
1 parent 33028ba commit 904f1f8

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ def visualize_model(self,sample_size=100):
283283
>>> from bayesml import gaussianmixture
284284
>>> import numpy as np
285285
>>> model = gaussianmixture.GenModel(
286+
>>> c_num_classes=3,
287+
>>> c_degree=1
286288
>>> pi_vec=np.array([0.444,0.444,0.112]),
287289
>>> mu_vecs=np.array([[-2.8],[-0.8],[2]]),
288290
>>> lambda_mats=np.array([[[6.25]],[[6.25]],[[100]]])
@@ -763,11 +765,11 @@ def _init_subsampling(self,x):
763765
for k in range(self.c_num_classes):
764766
_subsample = self.rng.choice(x,size=_size,replace=False,axis=0,shuffle=False)
765767
self.hn_m_vecs[k] = _subsample.sum(axis=0) / _size
766-
self.hn_w_mats[k] = ((_subsample - self.hn_m_vecs[k]).T
768+
self.hn_w_mats_inv[k] = ((_subsample - self.hn_m_vecs[k]).T
767769
@ (_subsample - self.hn_m_vecs[k])
768-
/ _size / self.hn_nus[k]
770+
/ _size * self.hn_nus[k]
769771
+ np.identity(self.c_degree) * 1.0E-5) # avoid singular matrix
770-
self.hn_w_mats_inv[k] = np.linalg.inv(self.hn_w_mats[k])
772+
self.hn_w_mats[k] = np.linalg.inv(self.hn_w_mats_inv[k])
771773
self._calc_q_pi_char()
772774
self._calc_q_lambda_char()
773775

bayesml/gaussianmixture/test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,15 @@
44
from time import time
55

66
gen_model = gaussianmixture.GenModel(
7-
c_num_classes=2,
8-
c_degree=2,
9-
pi_vec=np.ones(3) / 3,
7+
c_num_classes=3,
8+
c_degree=1,
9+
pi_vec=np.array([0.444,0.444,0.112]),
10+
mu_vecs=np.array([[-2.8],[-0.8],[2]]),
11+
lambda_mats=np.array([[[6.25]],[[6.25]],[[100]]])
1012
)
11-
print(gen_model.get_params())
13+
14+
x,z = gen_model.gen_sample(1000)
15+
16+
learn_model = gaussianmixture.LearnModel(3,1)
17+
learn_model.update_posterior(x)
18+
learn_model.visualize_posterior()

0 commit comments

Comments
 (0)