Skip to content

Commit 8ce2296

Browse files
committed
Initialization bug fix
1 parent 904f1f8 commit 8ce2296

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,12 @@ def _init_subsampling(self,x):
770770
/ _size * self.hn_nus[k]
771771
+ np.identity(self.c_degree) * 1.0E-5) # avoid singular matrix
772772
self.hn_w_mats[k] = np.linalg.inv(self.hn_w_mats_inv[k])
773-
self._calc_q_pi_char()
774773
self._calc_q_lambda_char()
775774

775+
def _init_rho_r(self):
776+
self._ln_rho[:] = 0.0
777+
self.r_vecs[:] = 1/self.c_num_classes
778+
776779
def update_posterior(
777780
self,
778781
x,
@@ -818,6 +821,8 @@ def update_posterior(
818821

819822
convergence_flag = True
820823
for i in range(num_init):
824+
self.reset_hn_params()
825+
self._init_rho_r()
821826
if init_type == 'subsampling':
822827
self._init_subsampling(x)
823828
self._update_q_z(x)

bayesml/gaussianmixture/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
lambda_mats=np.array([[[6.25]],[[6.25]],[[100]]])
1212
)
1313

14-
x,z = gen_model.gen_sample(1000)
14+
x,z = gen_model.gen_sample(300)
1515

1616
learn_model = gaussianmixture.LearnModel(3,1)
17-
learn_model.update_posterior(x)
17+
learn_model.update_posterior(x)#,init_type='random_responsibility')
1818
learn_model.visualize_posterior()

0 commit comments

Comments
 (0)