Skip to content

Commit 5ca3907

Browse files
committed
bug fix for checking pi_vec
1 parent 663a432 commit 5ca3907

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def set_h_params(
120120
_check.float_vecs(h_m_vecs,'h_m_vecs',ParameterFormatError)
121121
if h_m_vecs.shape[-1] != self.c_degree:
122122
raise(ParameterFormatError(
123-
"h_m_vecs.shape[-1] must coincide with self.c_degree:"
124-
+f"h_m_vecs.shape[-1]={h_m_vecs.shape[-1]}, self.c_degree={self.c_degree}"))
123+
"h_m_vecs.shape[-1] must coincide with self.c_degree: "
124+
+f"h_m_vecs.shape[-1] = {h_m_vecs.shape[-1]}, self.c_degree = {self.c_degree}"))
125125
self.h_m_vecs[:] = h_m_vecs
126126

127127
if h_kappas is not None:
@@ -132,16 +132,16 @@ def set_h_params(
132132
_check.pos_floats(h_nus,'h_nus',ParameterFormatError)
133133
if np.any(h_nus <= self.c_degree - 1):
134134
raise(ParameterFormatError(
135-
"c_degree must be smaller than h_nus + 1"
136-
+ f"self.c_degree={self.c_degree}, h_nus={h_nus}"))
135+
"All the values of h_nus must be greater than self.c_degree - 1: "
136+
+ f"self.c_degree = {self.c_degree}, h_nus = {h_nus}"))
137137
self.h_nus[:] = h_nus
138138

139139
if h_w_mats is not None:
140140
_check.pos_def_sym_mats(h_w_mats,'h_w_mats',ParameterFormatError)
141141
if h_w_mats.shape[-1] != self.c_degree:
142142
raise(ParameterFormatError(
143-
"h_w_mats.shape[-1] and h_w_mats.shape[-2] must coincide with self.c_degree:"
144-
+f"h_w_mats.shape[-1]={h_w_mats.shape[-1]}, h_w_mats.shape[-2]={h_w_mats.shape[-2]}, self.c_degree={self.c_degree}"))
143+
"h_w_mats.shape[-1] and h_w_mats.shape[-2] must coincide with self.c_degree: "
144+
+f"h_w_mats.shape[-1] and h_w_mats.shape[-2] = {h_w_mats.shape[-1]}, self.c_degree = {self.c_degree}"))
145145
self.h_w_mats[:] = h_w_mats
146146

147147
def get_h_params(self):
@@ -191,22 +191,26 @@ def set_params(
191191
"""
192192
if pi_vec is not None:
193193
_check.float_vec_sum_1(pi_vec,'pi_vec',ParameterFormatError)
194+
if pi_vec.shape[0] != self.c_num_classes:
195+
raise(ParameterFormatError(
196+
"pi_vec.shape[0] must coincide with self.c_num_classes: "
197+
+f"pi_vec.shape[0] = {pi_vec.shape[0]}, self.c_num_classes = {self.c_num_classes}"))
194198
self.pi_vec[:] = pi_vec
195199

196200
if mu_vecs is not None:
197201
_check.float_vecs(mu_vecs,'mu_vecs',ParameterFormatError)
198202
if mu_vecs.shape[-1] != self.c_degree:
199203
raise(ParameterFormatError(
200-
"mu_vecs.shape[-1] must coincide with self.c_degree:"
201-
+f"mu_vecs.shape[-1]={mu_vecs.shape[-1]}, self.c_degree={self.c_degree}"))
204+
"mu_vecs.shape[-1] must coincide with self.c_degree: "
205+
+f"mu_vecs.shape[-1] = {mu_vecs.shape[-1]}, self.c_degree = {self.c_degree}"))
202206
self.mu_vecs[:] = mu_vecs
203207

204208
if lambda_mats is not None:
205209
_check.pos_def_sym_mats(lambda_mats,'lambda_mats',ParameterFormatError)
206210
if lambda_mats.shape[-1] != self.c_degree:
207211
raise(ParameterFormatError(
208212
"lambda_mats.shape[-1] and lambda_mats.shape[-2] must coincide with self.c_degree:"
209-
+f"lambda_mats.shape[-1]={lambda_mats.shape[-1]}, lambda_mats.shape[-2]={lambda_mats.shape[-2]}, self.c_degree={self.c_degree}"))
213+
+f"lambda_mats.shape[-1] and lambda_mats.shape[-2] = {lambda_mats.shape[-1]}, self.c_degree = {self.c_degree}"))
210214
self.lambda_mats[:] = lambda_mats
211215

212216
def get_params(self):
@@ -504,8 +508,8 @@ def set_h0_params(
504508
_check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
505509
if h0_m_vecs.shape[-1] != self.c_degree:
506510
raise(ParameterFormatError(
507-
"h0_m_vecs.shape[-1] must coincide with self.c_degree:"
508-
+f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.c_degree={self.c_degree}"))
511+
"h0_m_vecs.shape[-1] must coincide with self.c_degree: "
512+
+f"h0_m_vecs.shape[-1] = {h0_m_vecs.shape[-1]}, self.c_degree = {self.c_degree}"))
509513
self.h0_m_vecs[:] = h0_m_vecs
510514

511515
if h0_kappas is not None:
@@ -516,15 +520,16 @@ def set_h0_params(
516520
_check.pos_floats(h0_nus,'h0_nus',ParameterFormatError)
517521
if np.any(h0_nus <= self.c_degree - 1):
518522
raise(ParameterFormatError(
519-
"c_degree must be smaller than h0_nus + 1"))
523+
"All the values of h0_nus must be greater than self.c_degree - 1: "
524+
+ f"self.c_degree = {self.c_degree}, h0_nus = {h0_nus}"))
520525
self.h0_nus[:] = h0_nus
521526

522527
if h0_w_mats is not None:
523528
_check.pos_def_sym_mats(h0_w_mats,'h0_w_mats',ParameterFormatError)
524529
if h0_w_mats.shape[-1] != self.c_degree:
525530
raise(ParameterFormatError(
526-
"h0_w_mats.shape[-1] and h0_w_mats.shape[-2] must coincide with self.c_degree:"
527-
+f"h0_w_mats.shape[-1]={h0_w_mats.shape[-1]}, h0_w_mats.shape[-2]={h0_w_mats.shape[-2]}, self.c_degree={self.c_degree}"))
531+
"h0_w_mats.shape[-1] and h0_w_mats.shape[-2] must coincide with self.c_degree: "
532+
+f"h0_w_mats.shape[-1] and h0_w_mats.shape[-2] = {h0_w_mats.shape[-1]}, self.c_degree = {self.c_degree}"))
528533
self.h0_w_mats[:] = h0_w_mats
529534
self.h0_w_mats_inv[:] = np.linalg.inv(self.h0_w_mats)
530535

@@ -580,8 +585,8 @@ def set_hn_params(
580585
_check.float_vecs(hn_m_vecs,'hn_m_vecs',ParameterFormatError)
581586
if hn_m_vecs.shape[-1] != self.c_degree:
582587
raise(ParameterFormatError(
583-
"hn_m_vecs.shape[-1] must coincide with self.c_degree:"
584-
+f"hn_m_vecs.shape[-1]={hn_m_vecs.shape[-1]}, self.c_degree={self.c_degree}"))
588+
"hn_m_vecs.shape[-1] must coincide with self.c_degree: "
589+
+f"hn_m_vecs.shape[-1] = {hn_m_vecs.shape[-1]}, self.c_degree = {self.c_degree}"))
585590
self.hn_m_vecs[:] = hn_m_vecs
586591

587592
if hn_kappas is not None:
@@ -592,15 +597,16 @@ def set_hn_params(
592597
_check.pos_floats(hn_nus,'hn_nus',ParameterFormatError)
593598
if np.any(hn_nus <= self.c_degree - 1):
594599
raise(ParameterFormatError(
595-
"c_degree must be smaller than hn_nus + 1"))
600+
"All the values of hn_nus must be greater than self.c_degree - 1: "
601+
+ f"self.c_degree = {self.c_degree}, hn_nus = {hn_nus}"))
596602
self.hn_nus[:] = hn_nus
597603

598604
if hn_w_mats is not None:
599605
_check.pos_def_sym_mats(hn_w_mats,'hn_w_mats',ParameterFormatError)
600606
if hn_w_mats.shape[-1] != self.c_degree:
601607
raise(ParameterFormatError(
602-
"hn_w_mats.shape[-1] and hn_w_mats.shape[-2] must coincide with self.c_degree:"
603-
+f"hn_w_mats.shape[-1]={hn_w_mats.shape[-1]}, hn_w_mats.shape[-2]={hn_w_mats.shape[-2]}, self.c_degree={self.c_degree}"))
608+
"hn_w_mats.shape[-1] and hn_w_mats.shape[-2] must coincide with self.c_degree: "
609+
+f"hn_w_mats.shape[-1] and hn_w_mats.shape[-2] = {hn_w_mats.shape[-1]}, self.c_degree = {self.c_degree}"))
604610
self.hn_w_mats[:] = hn_w_mats
605611
self.hn_w_mats_inv[:] = np.linalg.inv(self.hn_w_mats)
606612

0 commit comments

Comments
 (0)