@@ -537,7 +537,9 @@ def __init__(
537
537
self .hn_w_mats = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
538
538
self .hn_w_mats_inv = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
539
539
540
+ self ._length = 0
540
541
self ._ln_rho = None
542
+ self ._rho = None
541
543
self .alpha_vecs = None
542
544
self .beta_vecs = None
543
545
self .gamma_vecs = None
@@ -901,11 +903,203 @@ def calc_vl(self):
901
903
+ self ._vl_q_pi
902
904
+ self ._vl_q_a
903
905
+ self ._vl_q_mu_lambda )
906
+
907
def _init_fb_params(self):
    """Reset the per-sample working arrays to their starting values.

    Every array is overwritten in place: ``_ln_rho`` becomes 0,
    ``_rho``, ``beta_vecs`` and ``_cs`` become 1, ``alpha_vecs`` and
    ``gamma_vecs`` become ``1 / c_num_classes`` and ``xi_mats`` becomes
    ``1 / c_num_classes**2``, except for ``xi_mats[0]`` which is zeroed.
    """
    uniform = 1 / self.c_num_classes
    start_values = (
        (self._ln_rho, 0.0),
        (self._rho, 1.0),
        (self.alpha_vecs, uniform),
        (self.beta_vecs, 1.0),
        (self.gamma_vecs, uniform),
        (self.xi_mats, 1 / (self.c_num_classes ** 2)),
        (self._cs, 1.0),
    )
    for array, value in start_values:
        array[:] = value
    # The first entry of xi_mats is kept at zero (the other update
    # routines only ever write xi_mats[1:]).
    self.xi_mats[0] = 0.0
916
+
917
def _init_random_responsibility(self, x):
    """Fill ``xi_mats`` and ``gamma_vecs`` from random Dirichlet draws.

    One draw of length ``c_num_classes**2`` is taken from ``self.rng``
    for every entry of ``xi_mats`` and reshaped to a square matrix;
    ``xi_mats[0]`` is then zeroed. ``gamma_vecs`` is derived from
    ``xi_mats`` and ``_calc_n_m_x_bar_s(x)`` is called afterwards.

    Parameters
    ----------
    x : numpy.ndarray
        Passed through unchanged to ``_calc_n_m_x_bar_s``.
    """
    target_shape = self.xi_mats.shape
    concentration = np.ones(self.c_num_classes ** 2)
    flat_draws = self.rng.dirichlet(concentration, target_shape[0])
    self.xi_mats[:] = flat_draws.reshape(target_shape)
    self.xi_mats[0] = 0.0

    # Every gamma vector is xi summed over its middle axis ...
    self.gamma_vecs[:] = self.xi_mats.sum(axis=1)
    # ... except the first one: xi_mats[0] is zero, so it is taken from
    # xi_mats[1] summed over its last axis instead.
    self.gamma_vecs[0] = self.xi_mats[1].sum(axis=1)

    self._calc_n_m_x_bar_s(x)
923
+
924
def _init_subsampling(self, x):
    """Initialize ``hn_m_vecs``, ``hn_w_mats_inv`` and ``hn_w_mats`` from subsamples.

    For each class, ``int(sqrt(self._length))`` rows of ``x`` are drawn
    without replacement from ``self.rng``. Their mean becomes
    ``hn_m_vecs[k]``; their scatter around that mean, divided by the
    subsample size and multiplied by ``hn_nus[k]``, becomes
    ``hn_w_mats_inv[k]``. ``_calc_q_lambda_char()`` is called at the end.

    Parameters
    ----------
    x : numpy.ndarray
        Data to subsample from, sampled along axis 0.
    """
    num_rows = int(np.sqrt(self._length))
    # Small diagonal term added to each matrix to avoid a singular matrix.
    jitter = np.identity(self.c_degree) * 1.0E-5
    for cls in range(self.c_num_classes):
        rows = self.rng.choice(
            x, size=num_rows, replace=False, axis=0, shuffle=False)
        self.hn_m_vecs[cls] = rows.sum(axis=0) / num_rows
        centered = rows - self.hn_m_vecs[cls]
        self.hn_w_mats_inv[cls] = (
            centered.T @ centered / num_rows * self.hn_nus[cls] + jitter)
        self.hn_w_mats[cls] = np.linalg.inv(self.hn_w_mats_inv[cls])
    self._calc_q_lambda_char()
935
+
936
def _update_q_mu_lambda(self):
    """Recompute ``hn_kappas``, ``hn_m_vecs``, ``hn_nus`` and the ``hn_w_mats`` pair.

    The ``hn_*`` arrays are overwritten in place from the ``h0_*``
    values combined with ``ns``, ``x_bar_vecs`` and ``s_mats``;
    ``hn_w_mats`` is the matrix inverse of the new ``hn_w_mats_inv``.
    ``_calc_q_lambda_char()`` is called at the end.
    """
    counts = self.ns
    counts_col = counts[:, np.newaxis]
    counts_mat = counts[:, np.newaxis, np.newaxis]

    self.hn_kappas[:] = self.h0_kappas + counts
    weighted_sum = (self.h0_kappas[:, np.newaxis] * self.h0_m_vecs
                    + counts_col * self.x_bar_vecs)
    self.hn_m_vecs[:] = weighted_sum / self.hn_kappas[:, np.newaxis]
    self.hn_nus[:] = self.h0_nus + counts

    # Per-class outer product of (x_bar - h0_m) with itself.
    shift = self.x_bar_vecs - self.h0_m_vecs
    shift_outer = shift[:, :, np.newaxis] @ shift[:, np.newaxis, :]
    shrink = (self.h0_kappas * counts / self.hn_kappas)[:, np.newaxis, np.newaxis]
    self.hn_w_mats_inv[:] = (self.h0_w_mats_inv
                             + counts_mat * self.s_mats
                             + shrink * shift_outer)
    self.hn_w_mats[:] = np.linalg.inv(self.hn_w_mats_inv)
    self._calc_q_lambda_char()
904
949
905
- def update_posterior ():
950
def _update_q_pi(self):
    """Set ``hn_eta_vec`` to ``h0_eta_vec + ns`` in place, then call ``_calc_q_pi_char()``."""
    updated_eta = self.h0_eta_vec + self.ns
    self.hn_eta_vec[:] = updated_eta
    self._calc_q_pi_char()
953
+
954
def _update_q_a(self):
    """Set ``hn_zeta_vecs`` to ``h0_zeta_vecs + ms`` in place, then call ``_calc_q_a_char()``."""
    updated_zeta = self.h0_zeta_vecs + self.ms
    self.hn_zeta_vecs[:] = updated_zeta
    self._calc_q_a_char()
957
+
958
def _calc_rho(self, x):
    """Recompute ``_ln_rho`` and ``_rho`` for every row of ``x`` and every class.

    ``_ln_rho[i, k]`` is half of: ``_e_ln_lambda_dets[k]`` minus
    ``c_degree * log(2*pi)`` minus ``c_degree / hn_kappas[k]`` minus the
    quadratic form of ``x[i] - hn_m_vecs[k]`` under ``_e_lambda_mats[k]``.
    ``_rho`` is the element-wise exponential of ``_ln_rho``. Both arrays
    are overwritten in place.

    Parameters
    ----------
    x : numpy.ndarray
        Indexed as ``x[:, np.newaxis, :]``, i.e. a two-dimensional array
        whose rows are broadcast against ``hn_m_vecs``.
    """
    # deviation[i, k, :] = x[i] - hn_m_vecs[k]
    deviation = x[:, np.newaxis, :] - self.hn_m_vecs
    row_form = deviation[:, :, np.newaxis, :]
    col_form = deviation[:, :, :, np.newaxis]
    quad_form = (row_form @ self._e_lambda_mats @ col_form)[:, :, 0, 0]

    twice_ln_rho = (self._e_ln_lambda_dets
                    - self.c_degree * np.log(2 * np.pi)
                    - self.c_degree / self.hn_kappas
                    - quad_form)
    self._ln_rho[:] = twice_ln_rho / 2.0
    self._rho[:] = np.exp(self._ln_rho)
969
+
970
def _forward(self):
    """Fill ``alpha_vecs`` and ``_cs`` front to back.

    The first vector is ``_rho[0] * _pi_tilde_vec``; each later one is
    ``_rho[i] * (alpha_vecs[i - 1] @ _a_tilde_mat)``. Every vector is
    divided by its own sum, and that sum is stored in ``_cs[i]``.
    """
    unscaled = self._rho[0] * self._pi_tilde_vec
    self._cs[0] = unscaled.sum()
    self.alpha_vecs[0] = unscaled / self._cs[0]
    for step in range(1, self._length):
        propagated = self.alpha_vecs[step - 1] @ self._a_tilde_mat
        unscaled = self._rho[step] * propagated
        self._cs[step] = unscaled.sum()
        self.alpha_vecs[step] = unscaled / self._cs[step]
978
+
979
def _backward(self):
    """Fill ``beta_vecs`` back to front, reusing the scaling factors in ``_cs``.

    The last vector is left as it is; every earlier one is
    ``_a_tilde_mat @ (_rho[i + 1] * beta_vecs[i + 1])`` divided by
    ``_cs[i + 1]``.
    """
    for step in reversed(range(self._length - 1)):
        following = step + 1
        weighted = self._rho[following] * self.beta_vecs[following]
        self.beta_vecs[step] = self._a_tilde_mat @ weighted
        self.beta_vecs[step] /= self._cs[following]
983
+
984
def _update_gamma(self):
    """Overwrite ``gamma_vecs`` in place with the element-wise product of ``alpha_vecs`` and ``beta_vecs``."""
    product = self.alpha_vecs * self.beta_vecs
    self.gamma_vecs[:] = product
986
+
987
def _update_xi(self):
    """Recompute ``xi_mats[1:]`` in place; ``xi_mats[0]`` is not touched.

    ``xi_mats[i, j, k]`` becomes
    ``alpha_vecs[i-1, j] * _rho[i, k] * _a_tilde_mat[j, k] * beta_vecs[i, k] / _cs[i]``.
    """
    prev_alpha = self.alpha_vecs[:-1, :, np.newaxis]
    cur_rho = self._rho[1:, np.newaxis, :]
    transition = self._a_tilde_mat[np.newaxis, :, :]
    cur_beta = self.beta_vecs[1:, np.newaxis, :]
    scale = self._cs[1:, np.newaxis, np.newaxis]

    self.xi_mats[1:, :, :] = prev_alpha * cur_rho * transition * cur_beta
    self.xi_mats[1:, :, :] /= scale
990
+
991
def _update_q_z(self, x):
    """Run one full refresh of the per-sample quantities for data ``x``.

    Calls, in this order: ``_calc_rho(x)``, ``_forward()``,
    ``_backward()``, ``_update_gamma()``, ``_update_xi()`` and
    ``_calc_n_m_x_bar_s(x)``.

    Parameters
    ----------
    x : numpy.ndarray
        Passed unchanged to ``_calc_rho`` and ``_calc_n_m_x_bar_s``.
    """
    self._calc_rho(x)
    # These four steps take no arguments and must run in exactly this order.
    for step in (self._forward,
                 self._backward,
                 self._update_gamma,
                 self._update_xi):
        step()
    self._calc_n_m_x_bar_s(x)
998
+
999
def update_posterior(
        self,
        x,
        max_itr=100,
        num_init=10,
        tolerance=1.0E-8,
        init_type='subsampling'
        ):
    """Update the posterior distribution using training data.

    The iterative update is restarted ``num_init`` times from fresh
    initial values. The ``hn_*`` parameters of the restart that reached
    the largest variational lower bound (``self.vl``) are kept.

    Parameters
    ----------
    x : numpy.ndarray
        All the elements must be real number.
        ``x.shape[-1]`` must equal ``self.c_degree``; ``x`` is reshaped
        to ``(-1, self.c_degree)`` before use.
    max_itr : int, optional
        maximum number of iterations, by default 100
    num_init : int, optional
        number of initializations, by default 10
    tolerance : float, optional
        convergence criterion of variational lower bound, by default 1.0E-8
    init_type : str, optional
        type of initialization, by default 'subsampling'
        * 'subsampling': for each latent class, extract a subsample whose
          size is int(np.sqrt(x.shape[0])), and use its mean and
          covariance matrix as initial values of hn_m_vecs and
          hn_w_mats_inv (hn_w_mats is its inverse).
        * 'random_responsibility': randomly assign responsibility to
          xi_mats and gamma_vecs.

    Raises
    ------
    ValueError
        If ``init_type`` is not one of the supported names. This is
        checked before any state of the object is modified.
    DataFormatError
        If ``x.shape[-1]`` differs from ``self.c_degree``.

    Warns
    -----
    ResultWarning
        If no restart met the convergence criterion within ``max_itr``
        iterations.
    """
    # Validate init_type up front: checking it inside the restart loop
    # would only fail after the working arrays and hn_* parameters had
    # already been reset, and would not fail at all when num_init == 0.
    if init_type not in ('subsampling', 'random_responsibility'):
        raise ValueError(
            f'init_type={init_type} is unsupported. '
            + 'This function supports only '
            + '"subsampling" and "random_responsibility"')

    _check.float_vecs(x, 'x', DataFormatError)
    if x.shape[-1] != self.c_degree:
        raise DataFormatError(
            "x.shape[-1] must be self.c_degree: "
            + f"x.shape[-1]={x.shape[-1]}, self.c_degree={self.c_degree}")
    x = x.reshape(-1, self.c_degree)

    # Allocate the per-sample working arrays for this data length.
    self._length = x.shape[0]
    per_sample_shape = [self._length, self.c_num_classes]
    self._ln_rho = np.zeros(per_sample_shape)
    self._rho = np.ones(per_sample_shape)
    self.alpha_vecs = np.ones(per_sample_shape) / self.c_num_classes
    self.beta_vecs = np.ones(per_sample_shape)
    self.gamma_vecs = np.ones(per_sample_shape) / self.c_num_classes
    self.xi_mats = np.zeros(
        [self._length, self.c_num_classes, self.c_num_classes])
    self._cs = np.ones([self._length])

    # Snapshot buffers for the parameters of the best restart so far.
    hn_names = (
        'hn_eta_vec',
        'hn_zeta_vecs',
        'hn_m_vecs',
        'hn_kappas',
        'hn_nus',
        'hn_w_mats',
        'hn_w_mats_inv',
    )
    best_params = {name: np.copy(getattr(self, name)) for name in hn_names}
    best_vl = 0.0  # only compared from the second restart onwards

    never_converged = True
    for i in range(num_init):
        self._init_fb_params()
        self.reset_hn_params()
        if init_type == 'subsampling':
            self._init_subsampling(x)
            self._update_q_z(x)
        else:  # 'random_responsibility'
            self._init_random_responsibility(x)
        self.calc_vl()
        print(f'\r{i}. VL: {self.vl}', end='')
        for t in range(max_itr):
            vl_before = self.vl
            self._update_q_mu_lambda()
            self._update_q_pi()
            self._update_q_a()
            self._update_q_z(x)
            self.calc_vl()
            print(f'\r{i}. VL: {self.vl} t={t} ', end='')
            # Stop once the relative change of the lower bound is small.
            if np.abs((self.vl - vl_before) / vl_before) < tolerance:
                never_converged = False
                print('(converged)', end='')
                break
        if i == 0 or self.vl > best_vl:
            print('*')
            best_vl = self.vl
            for name, saved in best_params.items():
                saved[:] = getattr(self, name)
        else:
            print('')
    if never_converged:
        warnings.warn("Algorithm has not converged even once.", ResultWarning)

    # Restore the best parameters and refresh everything derived from them.
    for name, saved in best_params.items():
        getattr(self, name)[:] = saved
    self._calc_q_pi_char()
    self._calc_q_a_char()
    self._calc_q_lambda_char()
    self._update_q_z(x)
909
1103
910
1104
def estimate_params (self ,loss = "squared" ):
911
1105
"""Estimate the parameter under the given criterion.
0 commit comments