@@ -120,8 +120,8 @@ def set_h_params(
120
120
_check .float_vecs (h_m_vecs ,'h_m_vecs' ,ParameterFormatError )
121
121
if h_m_vecs .shape [- 1 ] != self .c_degree :
122
122
raise (ParameterFormatError (
123
- "h_m_vecs.shape[-1] must coincide with self.c_degree:"
124
- + f"h_m_vecs.shape[-1]= { h_m_vecs .shape [- 1 ]} , self.c_degree= { self .c_degree } " ))
123
+ "h_m_vecs.shape[-1] must coincide with self.c_degree: "
124
+ + f"h_m_vecs.shape[-1] = { h_m_vecs .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
125
125
self .h_m_vecs [:] = h_m_vecs
126
126
127
127
if h_kappas is not None :
@@ -132,16 +132,16 @@ def set_h_params(
132
132
_check .pos_floats (h_nus ,'h_nus' ,ParameterFormatError )
133
133
if np .any (h_nus <= self .c_degree - 1 ):
134
134
raise (ParameterFormatError (
135
- "c_degree must be smaller than h_nus + 1 "
136
- + f"self.c_degree= { self .c_degree } , h_nus= { h_nus } " ))
135
+ "All the values of h_nus must be greater than self.c_degree - 1: "
136
+ + f"self.c_degree = { self .c_degree } , h_nus = { h_nus } " ))
137
137
self .h_nus [:] = h_nus
138
138
139
139
if h_w_mats is not None :
140
140
_check .pos_def_sym_mats (h_w_mats ,'h_w_mats' ,ParameterFormatError )
141
141
if h_w_mats .shape [- 1 ] != self .c_degree :
142
142
raise (ParameterFormatError (
143
- "h_w_mats.shape[-1] and h_w_mats.shape[-2] must coincide with self.c_degree:"
144
- + f"h_w_mats.shape[-1]= { h_w_mats . shape [ - 1 ] } , h_w_mats.shape[-2]= { h_w_mats .shape [- 2 ]} , self.c_degree= { self .c_degree } " ))
143
+ "h_w_mats.shape[-1] and h_w_mats.shape[-2] must coincide with self.c_degree: "
144
+ + f"h_w_mats.shape[-1] and h_w_mats.shape[-2] = { h_w_mats .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
145
145
self .h_w_mats [:] = h_w_mats
146
146
147
147
def get_h_params (self ):
@@ -191,22 +191,26 @@ def set_params(
191
191
"""
192
192
if pi_vec is not None :
193
193
_check .float_vec_sum_1 (pi_vec ,'pi_vec' ,ParameterFormatError )
194
+ if pi_vec .shape [0 ] != self .c_num_classes :
195
+ raise (ParameterFormatError (
196
+ "pi_vec.shape[0] must coincide with self.c_num_classes: "
197
+ + f"pi_vec.shape[0] = { pi_vec .shape [0 ]} , self.c_num_classes = { self .c_num_classes } " ))
194
198
self .pi_vec [:] = pi_vec
195
199
196
200
if mu_vecs is not None :
197
201
_check .float_vecs (mu_vecs ,'mu_vecs' ,ParameterFormatError )
198
202
if mu_vecs .shape [- 1 ] != self .c_degree :
199
203
raise (ParameterFormatError (
200
- "mu_vecs.shape[-1] must coincide with self.c_degree:"
201
- + f"mu_vecs.shape[-1]= { mu_vecs .shape [- 1 ]} , self.c_degree= { self .c_degree } " ))
204
+ "mu_vecs.shape[-1] must coincide with self.c_degree: "
205
+ + f"mu_vecs.shape[-1] = { mu_vecs .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
202
206
self .mu_vecs [:] = mu_vecs
203
207
204
208
if lambda_mats is not None :
205
209
_check .pos_def_sym_mats (lambda_mats ,'lambda_mats' ,ParameterFormatError )
206
210
if lambda_mats .shape [- 1 ] != self .c_degree :
207
211
raise (ParameterFormatError (
208
212
"lambda_mats.shape[-1] and lambda_mats.shape[-2] must coincide with self.c_degree:"
209
- + f"lambda_mats.shape[-1]= { lambda_mats . shape [ - 1 ] } , lambda_mats.shape[-2]= { lambda_mats .shape [- 2 ]} , self.c_degree= { self .c_degree } " ))
213
+ + f"lambda_mats.shape[-1] and lambda_mats.shape[-2] = { lambda_mats .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
210
214
self .lambda_mats [:] = lambda_mats
211
215
212
216
def get_params (self ):
@@ -504,8 +508,8 @@ def set_h0_params(
504
508
_check .float_vecs (h0_m_vecs ,'h0_m_vecs' ,ParameterFormatError )
505
509
if h0_m_vecs .shape [- 1 ] != self .c_degree :
506
510
raise (ParameterFormatError (
507
- "h0_m_vecs.shape[-1] must coincide with self.c_degree:"
508
- + f"h0_m_vecs.shape[-1]= { h0_m_vecs .shape [- 1 ]} , self.c_degree= { self .c_degree } " ))
511
+ "h0_m_vecs.shape[-1] must coincide with self.c_degree: "
512
+ + f"h0_m_vecs.shape[-1] = { h0_m_vecs .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
509
513
self .h0_m_vecs [:] = h0_m_vecs
510
514
511
515
if h0_kappas is not None :
@@ -516,15 +520,16 @@ def set_h0_params(
516
520
_check .pos_floats (h0_nus ,'h0_nus' ,ParameterFormatError )
517
521
if np .any (h0_nus <= self .c_degree - 1 ):
518
522
raise (ParameterFormatError (
519
- "c_degree must be smaller than h0_nus + 1" ))
523
+ "All the values of h0_nus must be greater than self.c_degree - 1: "
524
+ + f"self.c_degree = { self .c_degree } , h0_nus = { h0_nus } " ))
520
525
self .h0_nus [:] = h0_nus
521
526
522
527
if h0_w_mats is not None :
523
528
_check .pos_def_sym_mats (h0_w_mats ,'h0_w_mats' ,ParameterFormatError )
524
529
if h0_w_mats .shape [- 1 ] != self .c_degree :
525
530
raise (ParameterFormatError (
526
- "h0_w_mats.shape[-1] and h0_w_mats.shape[-2] must coincide with self.c_degree:"
527
- + f"h0_w_mats.shape[-1]= { h0_w_mats . shape [ - 1 ] } , h0_w_mats.shape[-2]= { h0_w_mats .shape [- 2 ]} , self.c_degree= { self .c_degree } " ))
531
+ "h0_w_mats.shape[-1] and h0_w_mats.shape[-2] must coincide with self.c_degree: "
532
+ + f"h0_w_mats.shape[-1] and h0_w_mats.shape[-2] = { h0_w_mats .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
528
533
self .h0_w_mats [:] = h0_w_mats
529
534
self .h0_w_mats_inv [:] = np .linalg .inv (self .h0_w_mats )
530
535
@@ -580,8 +585,8 @@ def set_hn_params(
580
585
_check .float_vecs (hn_m_vecs ,'hn_m_vecs' ,ParameterFormatError )
581
586
if hn_m_vecs .shape [- 1 ] != self .c_degree :
582
587
raise (ParameterFormatError (
583
- "hn_m_vecs.shape[-1] must coincide with self.c_degree:"
584
- + f"hn_m_vecs.shape[-1]= { hn_m_vecs .shape [- 1 ]} , self.c_degree= { self .c_degree } " ))
588
+ "hn_m_vecs.shape[-1] must coincide with self.c_degree: "
589
+ + f"hn_m_vecs.shape[-1] = { hn_m_vecs .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
585
590
self .hn_m_vecs [:] = hn_m_vecs
586
591
587
592
if hn_kappas is not None :
@@ -592,15 +597,16 @@ def set_hn_params(
592
597
_check .pos_floats (hn_nus ,'hn_nus' ,ParameterFormatError )
593
598
if np .any (hn_nus <= self .c_degree - 1 ):
594
599
raise (ParameterFormatError (
595
- "c_degree must be smaller than hn_nus + 1" ))
600
+ "All the values of hn_nus must be greater than self.c_degree - 1: "
601
+ + f"self.c_degree = { self .c_degree } , hn_nus = { hn_nus } " ))
596
602
self .hn_nus [:] = hn_nus
597
603
598
604
if hn_w_mats is not None :
599
605
_check .pos_def_sym_mats (hn_w_mats ,'hn_w_mats' ,ParameterFormatError )
600
606
if hn_w_mats .shape [- 1 ] != self .c_degree :
601
607
raise (ParameterFormatError (
602
- "hn_w_mats.shape[-1] and hn_w_mats.shape[-2] must coincide with self.c_degree:"
603
- + f"hn_w_mats.shape[-1]= { hn_w_mats . shape [ - 1 ] } , hn_w_mats.shape[-2]= { hn_w_mats .shape [- 2 ]} , self.c_degree= { self .c_degree } " ))
608
+ "hn_w_mats.shape[-1] and hn_w_mats.shape[-2] must coincide with self.c_degree: "
609
+ + f"hn_w_mats.shape[-1] and hn_w_mats.shape[-2] = { hn_w_mats .shape [- 1 ]} , self.c_degree = { self .c_degree } " ))
604
610
self .hn_w_mats [:] = hn_w_mats
605
611
self .hn_w_mats_inv [:] = np .linalg .inv (self .hn_w_mats )
606
612
0 commit comments