Skip to content

Commit b38b5f4

Browse files
committed
bug fix for lambda_mats initialization
1 parent 5ca3907 commit b38b5f4

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
# params
6868
self.pi_vec = np.ones(self.c_num_classes) / self.c_num_classes
6969
self.mu_vecs = np.zeros([self.c_num_classes,self.c_degree])
70-
self.lambda_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,self.c_degree,self.c_degree])
70+
self.lambda_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,1,1])
7171

7272
# h_params
7373
self.h_alpha_vec = np.ones(self.c_num_classes) / 2

bayesml/gaussianmixture/test.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,8 @@
44
from time import time
55

66
gen_model = gaussianmixture.GenModel(
7-
num_classes=2,
8-
degree=2,
9-
mu_vecs=np.array([[-2,-2],[2,2]]),
7+
c_num_classes=2,
8+
c_degree=2,
9+
pi_vec=np.ones(3) / 3,
1010
)
11-
x,z = gen_model.gen_sample(sample_size=100)
12-
print(x.shape)
13-
14-
learn_model = gaussianmixture.LearnModel(num_classes=10, degree=2)
15-
learn_model.update_posterior(x)
16-
learn_model.visualize_posterior()
11+
print(gen_model.get_params())

0 commit comments

Comments
 (0)