@@ -76,50 +76,43 @@ def set_params(
76
76
positive definite symetric matrices
77
77
"""
78
78
79
+ # [Check values]
80
+ tmp_pi_vec = None if pi_vec is None else _check .float_vec_sum_1 (pi_vec , "pi_vec" , ParameterFormatError )
81
+ tmp_a_mat = None if a_mat is None else _check .float_vec_sum_1 (a_mat , "a_mat" , ParameterFormatError , ndim = 2 , sum_axis = 1 )
82
+ tmp_mu_vecs = None if mu_vecs is None else _check .float_vecs (mu_vecs , "mu_vecs" , ParameterFormatError )
83
+ tmp_lambda_mats = None if lambda_mats is None else _check .float_vecs (lambda_mats , "lambda_mats" , ParameterFormatError )
84
+
79
85
# [Dimension value consistency]
80
- if pi_vec is not None :
81
- if pi_vec .shape [0 ] != self .c_num_classes :
82
- raise (ParameterFormatError (
83
- "pi_vec.shape[0] must coincide with self.c_num_classes:"
84
- + f"pi_vec.shape[0]={ pi_vec .shape [0 ]} , self.c_num_classes={ self .c_num_classes } " ))
85
- self .pi_vec = _check .float_vec_sum_1 (pi_vec , "pi_vec" , ParameterFormatError )
86
- if a_mat is not None :
87
- if a_mat .shape [0 ] != self .c_num_classes or a_mat .shape [1 ] != self .c_num_classes :
88
- raise (ParameterFormatError (
89
- "a_mat.shape[0] and a_mat.shape[1] must coincide with self.c_num_classes:"
90
- + f"a_mat.shape[0]={ a_mat .shape [0 ]} , a_mat.shape[1]={ a_mat .shape [1 ]} , self.c_num_classes={ self .c_num_classes } " ))
91
- self .a_mat = _check .float_vec_sum_1 (a_mat , "a_mat" , ParameterFormatError , ndim = 2 , sum_axis = 1 )
92
- if mu_vecs is not None :
93
- message = ""
94
- if mu_vecs .shape [0 ] != self .c_degree :
95
- message += "mu_vecs.shape[0] must coincide with self.c_degree:"
96
- + f"mu_vecs.shape[0]={ mu_vecs .shape [0 ]} , self.c_degree={ self .c_degree } "
97
- if len (mu_vecs ) == 2 :
98
- if mu_vecs .shape [1 ] != self .c_num_classes :
99
- message += "mu_vecs.shape[1] must coincide with self.c_num_classes:"
100
- + f"mu_vecs.shape[1]={ mu_vecs .shape [1 ]} , self.c_num_classes={ self .c_num_classes } "
101
- if message != "" :
102
- raise (ParameterFormatError (message ))
103
- self .mu_vecs = _check .float_vecs (mu_vecs , "mu_vecs" , ParameterFormatError )
104
- if len (mu_vecs ) == 1 :
105
- self .mu_vecs = np .broadcast_to (np .reshape (self .mu_vecs , (self .c_degree , 1 )), (self .c_degree , self .c_num_classes ))
106
- if lambda_mats is not None :
107
- message = ""
108
- if lambda_mats .shape [:2 ] != (self .c_degree , self .c_degree ):
109
- message += "lambda_mats.shape[:2] must coincide with (self.c_degree, self.c_degree):"
110
- + f"lambda_mats.shape[:2]={ lambda_mats .shape [:2 ]} , (self.c_degree, self.c_degree)={ (self .c_degree , self .c_degree )} "
111
- if len (lambda_mats ) == 3 :
112
- if lambda_mats .shape [2 ] != self .c_num_classes :
113
- message += "lambda_mats.shape[2] must coincide with self.c_num_classes:"
114
- + f"lambda_mats.shape[2]={ lambda_mats .shape [2 ]} , self.c_num_classes={ self .c_num_classes } "
115
- if message != "" :
116
- raise (ParameterFormatError (message ))
117
- self .lambda_mats = _check .float_vecs (lambda_mats , "lambda_mats" , ParameterFormatError )
118
- if len (self .lambda_mats ) == 2 :
119
- self .lambda_mats = np .broadcast_to (
120
- np .reshape (self .lambda_mats , (self .c_degree ,self .c_degree ,1 )),
121
- (self .c_degree ,self .c_degree ,self .c_num_classes )
122
- )
86
+ if tmp_pi_vec is not None :
87
+ _check .shape_consistency (
88
+ val = pi_vec , val_name = "pi_vec" ,
89
+ correct = [(self .c_num_classes ,)], correct_name = "(self.c_num_classes,)" ,
90
+ exception_class = ParameterFormatError
91
+ )
92
+ self .pi_vec [:] = tmp_pi_vec
93
+ if tmp_a_mat is not None :
94
+ _check .shape_consistency (
95
+ val = a_mat .shape , val_name = "a_mat" ,
96
+ correct = [(self .c_num_classes , self .c_num_classes )], correct_name = "(self.c_num_classes, self.c_num_classes)" ,
97
+ exception_class = ParameterFormatError
98
+ )
99
+ self .a_mat [:] = tmp_a_mat
100
+ if tmp_mu_vecs is not None :
101
+ _check .shape_consistency (
102
+ val = mu_vecs .shape , val_name = "mu_vecs" ,
103
+ correct = [(self .c_degree ,), (self .c_num_classes , self .c_degree )],
104
+ correct_name = "(self.c_degree,), (self.c_num_classes, self.c_degree)" ,
105
+ exception_class = ParameterFormatError
106
+ )
107
+ self .mu_vecs [:] = tmp_mu_vecs
108
+ if tmp_lambda_mats is not None :
109
+ _check .shape_consistency (
110
+ val = lambda_mats .shape , val_name = "lambda_mats" ,
111
+ correct = [(self .c_degree ,self .c_degree ), (self .c_num_classes ,self .c_degree ,self .c_degree )],
112
+ correct_name = "(self.c_degree,self.c_degree), (self.c_num_classes,self.c_degree,self.c_degree)" ,
113
+ exception_class = ParameterFormatError
114
+ )
115
+ self .lambda_mats [:] = tmp_lambda_mats
123
116
124
117
def set_h_params (
125
118
self ,
0 commit comments