Skip to content

Commit 8e40446

Browse files
committed
Fix bugs
1 parent b77ce3e commit 8e40446

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

bayesml/_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from bayesml import contexttree
2+
3+
gen_model = contexttree.GenModel(c_k=2,c_d_max=3,h_g=0.75)
4+
gen_model.gen_params()
5+
gen_model.visualize_model()
6+
x = gen_model.gen_sample(500)
7+
learn_model = contexttree.LearnModel(c_k=2,c_d_max=3)
8+
learn_model.update_posterior(x)
9+
# learn_model.visualize_posterior()

bayesml/bernoulli/_bernoulli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def visualize_posterior(self):
368368
p_range = np.linspace(0,1,100,endpoint=False)
369369
fig, ax = plt.subplots()
370370
ax.plot(p_range,self.estimate_params(loss="KL").pdf(p_range))
371-
ax.set_xlabel("p_theta")
371+
ax.set_xlabel("theta")
372372
ax.set_ylabel("posterior")
373373
plt.show()
374374

bayesml/categorical/_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def update_posterior(self, x):
349349
2-dimensional array whose shape is ``(sample_size,degree)`` whose rows are one-hot vectors.
350350
"""
351351
_check.onehot_vecs(x,'x',DataFormatError)
352-
if self.degree > 1 and x.shape[-1] != self.degree:
352+
if x.shape[-1] != self.degree:
353353
raise(DataFormatError(f"x.shape[-1] must be degree:{self.degree}"))
354354
x = x.reshape(-1,self.degree)
355355

bayesml/linearregression/_linearregression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def update_posterior(self, x, y):
542542
float array.
543543
"""
544544
_check.float_vecs(x,'x',DataFormatError)
545-
if self.degree > 1 and x.shape[-1] != self.degree:
545+
if x.shape[-1] != self.degree:
546546
raise(DataFormatError(f"x.shape[-1] must be degree:{self.degree}"))
547547
_check.floats(y,'y',DataFormatError)
548548
if type(y) is np.ndarray:

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ def __init__(
173173
)
174174

175175
# params
176-
self.root = _GenNode(0,list(range(self.c_k)),self.c_num_children,self.h_g,None,None)
176+
self.root = _GenNode(0,list(range(self.c_k)),self.c_num_children,self.h_g,0,self.SubModel(**self.sub_h_params))
177+
self.root.leaf = True
177178

178179
self.set_params(root)
179180

@@ -1282,6 +1283,7 @@ def visualize_posterior(self,filename=None,format=None):
12821283
--------
12831284
>>> from bayesml import metatree
12841285
>>> gen_model = metatree.GenModel(c_k=3,h_g=0.75)
1286+
>>> gen_model.gen_params()
12851287
>>> x,y = gen_model.gen_sample(500)
12861288
>>> learn_model = metatree.LearnModel(c_k=3)
12871289
>>> learn_model.update_posterior(x,y)

bayesml/multivariate_normal/_multivariatenormal.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def update_posterior(self,x):
614614
All the elements must be real number.
615615
"""
616616
_check.float_vecs(x,'x',DataFormatError)
617-
if self.degree > 1 and x.shape[-1] != self.degree:
617+
if x.shape[-1] != self.degree:
618618
raise(DataFormatError(f"x.shape[-1] must be degree:{self.degree}"))
619619
x = x.reshape(-1,self.degree)
620620

@@ -721,7 +721,7 @@ def visualize_posterior(self):
721721
print(f"{self.hn_w_mat}")
722722
print("E[lambda_mat]=")
723723
print(f"{self.hn_nu * self.hn_w_mat}")
724-
mu_vec_pdf, w_mat_pdf = self.estimate_params(loss="KL")
724+
mu_vec_pdf, lambda_mat_pdf = self.estimate_params(loss="KL")
725725
if self.degree == 1:
726726
fig, axes = plt.subplots(1,2)
727727
# for mu_vec
@@ -732,12 +732,12 @@ def visualize_posterior(self):
732732
axes[0].set_xlabel("mu_vec")
733733
axes[0].set_ylabel("Density")
734734
# for lambda_mat
735-
x = np.linspace(max(1.0e-8,self.hn_nu*self.hn_w_mat)-4.0*np.sqrt(self.hn_nu/2.0)*(2.0*self.hn_w_mat),
735+
x = np.linspace(max(1.0e-8,self.hn_nu*self.hn_w_mat-4.0*np.sqrt(self.hn_nu/2.0)*(2.0*self.hn_w_mat)),
736736
self.hn_nu*self.hn_w_mat+4.0*np.sqrt(self.hn_nu/2.0)*(2.0*self.hn_w_mat),
737737
100)
738738
print(self.hn_w_mat)
739-
axes[1].plot(x[:,0,0],w_mat_pdf.pdf(x[:,0,0]))
740-
axes[1].set_xlabel("w_mat")
739+
axes[1].plot(x[:,0,0],lambda_mat_pdf.pdf(x[:,0,0]))
740+
axes[1].set_xlabel("lambda_mat")
741741
axes[1].set_ylabel("Density")
742742

743743
fig.tight_layout()

0 commit comments

Comments
 (0)