8
8
from scipy .stats import wishart as ss_wishart
9
9
from scipy .stats import multivariate_t as ss_multivariate_t
10
10
from scipy .stats import dirichlet as ss_dirichlet
11
- from scipy .special import gammaln , digamma , xlogy
11
+ from scipy .special import gammaln , digamma , xlogy , logsumexp
12
12
import matplotlib .pyplot as plt
13
13
14
14
from .. import base
@@ -499,15 +499,28 @@ def __init__(
499
499
self .hn_w_mats = np .empty ([self .num_classes ,self .degree ,self .degree ])
500
500
self .hn_w_mats_inv = np .empty ([self .num_classes ,self .degree ,self .degree ])
501
501
502
- # statistics
502
+ self . ln_rho = None
503
503
self .r_vecs = None
504
- self .x_bar_vecs = np .empty ([self .num_classes ,self .degree ])
505
- self .ns = np .empty (self .num_classes )
506
- self .s_mats = np .empty ([self .num_classes ,self .degree ,self .degree ])
507
504
self .e_lambda_mats = np .empty ([self .num_classes ,self .degree ,self .degree ])
508
505
self .e_ln_lambda_dets = np .empty (self .num_classes )
506
+ self .ln_b_hn_w_nus = np .empty (self .num_classes )
509
507
self .e_ln_pi_vec = np .empty (self .num_classes )
510
508
509
+ # statistics
510
+ self .x_bar_vecs = np .empty ([self .num_classes ,self .degree ])
511
+ self .ns = np .empty (self .num_classes )
512
+ self .s_mats = np .empty ([self .num_classes ,self .degree ,self .degree ])
513
+
514
+ # variational lower bound
515
+ self .vl = 0.0
516
+ self .vl_p_x = 0.0
517
+ self .vl_p_z = 0.0
518
+ self .vl_p_pi = 0.0
519
+ self .vl_p_mu_lambda = 0.0
520
+ self .vl_q_z = 0.0
521
+ self .vl_q_pi = 0.0
522
+ self .vl_q_mu_lambda = 0.0
523
+
511
524
# p_params
512
525
self .p_pi_vec = np .empty ([self .num_classes ])
513
526
self .p_mu_vecs = np .empty ([self .num_classes ,self .degree ])
@@ -695,6 +708,21 @@ def reset_hn_params(self):
695
708
self .hn_w_mats [:] = self .h0_w_mats
696
709
self .hn_w_mats_inv = np .linalg .inv (self .hn_w_mats )
697
710
711
+ self .e_lambda_mats [:] = self .hn_nus [:,np .newaxis ,np .newaxis ] * self .hn_w_mats
712
+ self .e_ln_lambda_dets [:] = (
713
+ np .sum (digamma ((self .hn_nus [:,np .newaxis ]- np .arange (self .degree )) / 2.0 ),axis = 1 )
714
+ + self .degree * np .log (2.0 )
715
+ - np .linalg .slogdet (self .hn_w_mats_inv )[1 ]
716
+ )
717
+ self .e_ln_pi_vec [:] = digamma (self .hn_alpha_vec ) - digamma (self .hn_alpha_vec .sum ())
718
+ self .ln_b_hn_w_nus [:] = (
719
+ self .hn_nus * np .linalg .slogdet (self .hn_w_mats_inv )[1 ]
720
+ - self .hn_nus * self .degree * np .log (2.0 )
721
+ - self .degree * (self .degree - 1 )/ 2.0 * np .log (np .pi )
722
+ - np .sum (gammaln ((self .hn_nus [:,np .newaxis ]- np .arange (self .degree )) / 2.0 ),
723
+ axis = 1 ) * 2.0
724
+ ) / 2.0
725
+
698
726
self .calc_pred_dist ()
699
727
700
728
def overwrite_h0_params (self ):
@@ -713,22 +741,8 @@ def overwrite_h0_params(self):
713
741
self .calc_pred_dist ()
714
742
715
743
def calc_vl (self ):
716
- self .e_lambda_mats = self .hn_nus [:,np .newaxis ,np .newaxis ] * self .hn_w_mats
717
- self .e_ln_lambda_dets = (np .sum (digamma ((self .hn_nus [:,np .newaxis ]- np .arange (self .degree )) / 2.0 ),axis = 1 )
718
- + self .degree * np .log (2.0 )
719
- - np .linalg .slogdet (self .hn_w_mats_inv )[1 ])
720
- self .e_ln_pi_vec = digamma (self .hn_alpha_vec ) - digamma (self .hn_alpha_vec .sum ())
721
-
722
- # tentative
723
- self .ns = np .ones (self .num_classes ) * 10
724
- self .s_mats = np .tile (np .identity (self .degree ),[self .num_classes ,1 ,1 ]) * 5
725
- self .r_vecs = np .ones ([20 ,self .degree ])/ self .degree
726
- self .x_bar_vecs = np .ones ([self .num_classes ,self .degree ])
727
-
728
- vl = 0.0
729
-
730
744
# E[ln p(X|Z,mu,Lambda)]
731
- vl + = np .sum (
745
+ self . vl_p_x = np .sum (
732
746
self .ns
733
747
* (self .e_ln_lambda_dets - self .degree / self .hn_kappas
734
748
- (self .s_mats * self .e_lambda_mats ).sum (axis = (1 ,2 ))
@@ -741,63 +755,173 @@ def calc_vl(self):
741
755
) / 2.0
742
756
743
757
# E[ln p(Z|pi)]
744
- vl + = (self .ns * self .e_ln_pi_vec ).sum ()
758
+ self . vl_p_z = (self .ns * self .e_ln_pi_vec ).sum ()
745
759
746
760
# E[ln p(pi)]
747
- vl + = self .LN_C_H0_ALPHA + ((self .h0_alpha_vec - 1 ) * self .e_ln_pi_vec ).sum ()
761
+ self . vl_p_pi = self .LN_C_H0_ALPHA + ((self .h0_alpha_vec - 1 ) * self .e_ln_pi_vec ).sum ()
748
762
749
763
# E[ln p(mu,Lambda)]
750
- vl += np .sum (
751
- self .degree * (np .log (self .h0_kappas ) - np .log (2 * np .pi ) - self .h0_kappas / self .hn_kappas )
752
- - ((self .hn_m_vecs - self .h0_m_vecs )[:,np .newaxis ,:]
753
- @ self .e_lambda_mats
754
- @ (self .hn_m_vecs - self .h0_m_vecs )[:,:,np .newaxis ])[:,0 ,0 ]
764
+ self .vl_p_mu_lambda = np .sum (
765
+ self .degree * (np .log (self .h0_kappas ) - np .log (2 * np .pi )
766
+ - self .h0_kappas / self .hn_kappas )
767
+ - self .h0_kappas * ((self .hn_m_vecs - self .h0_m_vecs )[:,np .newaxis ,:]
768
+ @ self .e_lambda_mats
769
+ @ (self .hn_m_vecs - self .h0_m_vecs )[:,:,np .newaxis ])[:,0 ,0 ]
755
770
+ 2.0 * self .LN_B_H0_W_NUS
756
- + (self .h0_nus - self .degree ) / 2.0 * self .e_ln_lambda_dets
757
- - np .sum (self .h0_w_mats_inv * self .hn_w_mats ,axis = (1 ,2 ))
771
+ + (self .h0_nus - self .degree ) * self .e_ln_lambda_dets
772
+ - np .sum (self .h0_w_mats_inv * self .e_lambda_mats ,axis = (1 ,2 ))
758
773
) / 2.0
759
774
760
775
# E[ln q(Z|pi)]
761
- vl -= np .sum (xlogy (self .r_vecs ,self .r_vecs ))
776
+ self . vl_q_z = - np .sum (xlogy (self .r_vecs ,self .r_vecs ))
762
777
763
778
# E[ln q(pi)]
764
- vl + = ss_dirichlet .entropy (self .hn_alpha_vec )
779
+ self . vl_q_pi = ss_dirichlet .entropy (self .hn_alpha_vec )
765
780
766
781
# E[ln q(mu,Lambda)]
767
- vl + = np .sum (
782
+ self . vl_q_mu_lambda = np .sum (
768
783
+ self .degree * (1.0 + np .log (2.0 * np .pi ) - np .log (self .hn_kappas ))
769
- - self .LN_B_H0_W_NUS * 2.0
784
+ - self .ln_b_hn_w_nus * 2.0
770
785
- (self .hn_nus - self .degree )* self .e_ln_lambda_dets
771
786
+ self .hn_nus * self .degree
772
787
) / 2.0
773
788
774
- return vl
789
+ # print(self.vl_p_x,
790
+ # self.vl_p_z,
791
+ # self.vl_p_pi,
792
+ # self.vl_p_mu_lambda,
793
+ # self.vl_q_z,
794
+ # self.vl_q_pi,
795
+ # self.vl_q_mu_lambda,
796
+ # )
797
+
798
+ self .vl = (self .vl_p_x
799
+ + self .vl_p_z
800
+ + self .vl_p_pi
801
+ + self .vl_p_mu_lambda
802
+ + self .vl_q_z
803
+ + self .vl_q_pi
804
+ + self .vl_q_mu_lambda )
805
+
806
+ def _calc_statistics (self ,x ):
807
+ self .ns [:] = self .r_vecs .sum (axis = 0 )
808
+ self .x_bar_vecs [:] = (self .r_vecs [:,:,np .newaxis ] * x [:,np .newaxis ,:]).sum (axis = 0 ) / self .ns [:,np .newaxis ]
809
+ self .s_mats [:] = np .sum (self .r_vecs [:,:,np .newaxis ,np .newaxis ]
810
+ * ((x [:,np .newaxis ,:] - self .x_bar_vecs )[:,:,:,np .newaxis ]
811
+ @ (x [:,np .newaxis ,:] - self .x_bar_vecs )[:,:,np .newaxis ,:]),
812
+ axis = 0 ) / self .ns [:,np .newaxis ,np .newaxis ]
813
+
814
+ def _init_q_z (self ):
815
+ self .r_vecs [:] = self .rng .dirichlet (np .ones (self .num_classes ),self .r_vecs .shape [0 ])
816
+
817
+ def _update_q_pi (self ):
818
+ self .hn_alpha_vec [:] = self .h0_alpha_vec + self .ns
819
+ self .e_ln_pi_vec [:] = digamma (self .hn_alpha_vec ) - digamma (self .hn_alpha_vec .sum ())
820
+
821
+ def _update_q_mu_lambda (self ):
822
+ self .hn_kappas [:] = self .h0_kappas + self .ns
823
+ self .hn_m_vecs [:] = (self .h0_kappas [:,np .newaxis ] * self .h0_m_vecs
824
+ + self .ns [:,np .newaxis ] * self .x_bar_vecs ) / self .hn_kappas [:,np .newaxis ]
825
+ self .hn_nus [:] = self .h0_nus + self .ns
826
+ self .hn_w_mats_inv [:] = (self .h0_w_mats_inv
827
+ + self .ns [:,np .newaxis ,np .newaxis ] * self .s_mats
828
+ + (self .h0_kappas * self .ns / self .hn_kappas )[:,np .newaxis ,np .newaxis ]
829
+ * ((self .x_bar_vecs - self .h0_m_vecs )[:,:,np .newaxis ]
830
+ @ (self .x_bar_vecs - self .h0_m_vecs )[:,np .newaxis ,:])
831
+ )
832
+ self .hn_w_mats [:] = np .linalg .inv (self .hn_w_mats_inv )
833
+ self .e_lambda_mats [:] = self .hn_nus [:,np .newaxis ,np .newaxis ] * self .hn_w_mats
834
+ self .e_ln_lambda_dets [:] = (np .sum (digamma ((self .hn_nus [:,np .newaxis ]- np .arange (self .degree )) / 2.0 ),axis = 1 )
835
+ + self .degree * np .log (2.0 )
836
+ - np .linalg .slogdet (self .hn_w_mats_inv )[1 ])
837
+ self .ln_b_hn_w_nus [:] = (
838
+ self .hn_nus * np .linalg .slogdet (self .hn_w_mats_inv )[1 ]
839
+ - self .hn_nus * self .degree * np .log (2.0 )
840
+ - self .degree * (self .degree - 1 )/ 2.0 * np .log (np .pi )
841
+ - np .sum (gammaln ((self .hn_nus [:,np .newaxis ]- np .arange (self .degree )) / 2.0 ),
842
+ axis = 1 ) * 2.0
843
+ ) / 2.0
775
844
776
- def update_posterior (self ,x ):
777
- pass
778
- # """Update the hyperparameters of the posterior distribution using traning data.
845
+ def _update_q_z (self ,x ):
846
+ self .ln_rho [:] = (self .e_ln_pi_vec
847
+ + (self .e_ln_lambda_dets
848
+ - self .degree * np .log (2 * np .pi )
849
+ - self .degree / self .hn_kappas
850
+ - ((x [:,np .newaxis ,:]- self .hn_m_vecs )[:,:,np .newaxis ,:]
851
+ @ self .e_lambda_mats
852
+ @ (x [:,np .newaxis ,:]- self .hn_m_vecs )[:,:,:,np .newaxis ]
853
+ )[:,:,0 ,0 ]
854
+ ) / 2.0
855
+ )
856
+ self .r_vecs [:] = np .exp (self .ln_rho - logsumexp (self .ln_rho ,axis = 1 ,keepdims = True ))
857
+ # self.r_vecs[:] = np.exp(self.ln_rho - self.ln_rho.max(axis=1,keepdims=True))
858
+ # self.r_vecs[:] /= self.r_vecs.sum(axis=1,keepdims=True)
859
+
860
+ def update_posterior (self ,x ,max_itr = 100 ,num_init = 10 ,tolerance = 1.0E-8 ):
861
+ """Update the hyperparameters of the posterior distribution using traning data.
779
862
780
- # Parameters
781
- # ----------
782
- # x : numpy.ndarray
783
- # All the elements must be real number.
784
- # """
785
- # _check.float_vecs(x,'x',DataFormatError)
786
- # if self.degree > 1 and x.shape[-1] != self.degree:
787
- # raise(DataFormatError(f"x.shape[-1] must be degree:{self.degree}"))
788
- # x = x.reshape(-1,self.degree)
789
-
790
- # n = x.shape[0]
791
- # x_bar = x.sum(axis=0)/n
792
-
793
- # self.hn_w_mat_inv[:] = (self.hn_w_mat_inv + (x-x_bar).T @ (x-x_bar)
794
- # + (x_bar - self.hn_m_vec)[:,np.newaxis] @ (x_bar - self.hn_m_vec)[np.newaxis,:]
795
- # * self.hn_kappa * n / (self.hn_kappa + n))
796
- # self.hn_m_vec[:] = (self.hn_kappa*self.hn_m_vec + n*x_bar) / (self.hn_kappa+n)
797
- # self.hn_kappa += n
798
- # self.hn_nu += n
799
-
800
- # self.hn_w_mat[:] = np.linalg.inv(self.hn_w_mat_inv)
863
+ Parameters
864
+ ----------
865
+ x : numpy.ndarray
866
+ All the elements must be real number.
867
+ max_itr : int, optional
868
+ maximum number of iterations, by default 100
869
+ num_init : int, optional
870
+ number of initializations, by default 10
871
+ tolerance : float, optional
872
+ convergence croterion of variational lower bound, by default 1.0E-8
873
+ """
874
+ _check .float_vecs (x ,'x' ,DataFormatError )
875
+ if self .degree > 1 and x .shape [- 1 ] != self .degree :
876
+ raise (DataFormatError (
877
+ "x.shape[-1] must be self.degree: "
878
+ + f"x.shape[-1]={ x .shape [- 1 ]} , self.degree={ self .degree } " ))
879
+ x = x .reshape (- 1 ,self .degree )
880
+ self .ln_rho = np .empty ([x .shape [0 ],self .num_classes ])
881
+ self .r_vecs = np .empty ([x .shape [0 ],self .num_classes ])
882
+
883
+ tmp_vl = 0.0
884
+ tmp_alpha_vec = np .copy (self .hn_alpha_vec )
885
+ tmp_m_vecs = np .copy (self .hn_m_vecs )
886
+ tmp_kappas = np .copy (self .hn_kappas )
887
+ tmp_nus = np .copy (self .hn_nus )
888
+ tmp_w_mats = np .copy (self .hn_w_mats )
889
+ tmp_w_mats_inv = np .copy (self .hn_w_mats_inv )
890
+
891
+ for i in range (num_init ):
892
+ self ._init_q_z ()
893
+ self ._calc_statistics (x )
894
+ self .calc_vl ()
895
+ print (f'\r { i } . VL: { self .vl } ' ,end = '' )
896
+ for t in range (max_itr ):
897
+ vl_before = self .vl
898
+
899
+ self ._update_q_mu_lambda ()
900
+ self ._update_q_pi ()
901
+ self ._update_q_z (x )
902
+ self ._calc_statistics (x )
903
+ self .calc_vl ()
904
+ print (f'\r { i } . VL: { self .vl } ' ,end = '' )
905
+ if np .abs ((self .vl - vl_before )/ vl_before ) < tolerance :
906
+ break
907
+ if i == 0 or self .vl > tmp_vl :
908
+ print ('*' )
909
+ tmp_vl = self .vl
910
+ tmp_alpha_vec [:] = self .hn_alpha_vec
911
+ tmp_m_vecs [:] = self .hn_m_vecs
912
+ tmp_kappas [:] = self .hn_kappas
913
+ tmp_nus [:] = self .hn_nus
914
+ tmp_w_mats [:] = self .hn_w_mats
915
+ tmp_w_mats_inv [:] = self .hn_w_mats_inv
916
+ else :
917
+ print ('' )
918
+
919
+ self .hn_alpha_vec [:] = tmp_alpha_vec
920
+ self .hn_m_vecs [:] = tmp_m_vecs
921
+ self .hn_kappas [:] = tmp_kappas
922
+ self .hn_nus [:] = tmp_nus
923
+ self .hn_w_mats [:] = tmp_w_mats
924
+ self .hn_w_mats_inv [:] = tmp_w_mats_inv
801
925
802
926
def estimate_params (self ,loss = "squared" ):
803
927
pass
0 commit comments