@@ -423,6 +423,11 @@ def visualize_model(self,sample_length=200):
423
423
alpha = 0.3 ,
424
424
ls = '' ,
425
425
)
426
+ axes .plot (
427
+ np .linspace (change_points [i - 1 ],change_points [i ],100 ),
428
+ np .ones (100 ) * self .mu_vecs [np .argmax (latent_vars [change_points [i - 1 ]])],
429
+ c = 'red' ,
430
+ )
426
431
axes .plot (np .arange (sample .shape [0 ]),sample )
427
432
axes .set_xlabel ("time" )
428
433
axes .set_ylabel ("x" )
@@ -519,6 +524,10 @@ def __init__(
519
524
self .h0_w_mats = np .tile (np .identity (self .c_degree ),[self .c_num_classes ,1 ,1 ])
520
525
self .h0_w_mats_inv = np .linalg .inv (self .h0_w_mats )
521
526
527
+ self ._ln_c_h0_eta_vec = 0.0
528
+ self ._ln_c_h0_zeta_vecs_sum = 0.0
529
+ self ._ln_b_h0_w_nus = np .empty (self .c_num_classes )
530
+
522
531
# hn_params
523
532
self .hn_eta_vec = np .empty (self .c_num_classes )
524
533
self .hn_zeta_vecs = np .empty ([self .c_num_classes ,self .c_num_classes ])
@@ -528,7 +537,41 @@ def __init__(
528
537
self .hn_w_mats = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
529
538
self .hn_w_mats_inv = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
530
539
540
+ self ._ln_rho = None
541
+ self .alpha_vecs = None
542
+ self .beta_vecs = None
543
+ self .gamma_vecs = None
544
+ self .xi_mats = None
545
+ self ._cs = None
546
+ self ._e_lambda_mats = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
547
+ self ._e_ln_lambda_dets = np .empty (self .c_num_classes )
548
+ self ._ln_b_hn_w_nus = np .empty (self .c_num_classes )
549
+ self ._ln_pi_tilde_vec = np .empty (self .c_num_classes )
550
+ self ._pi_tilde_vec = np .empty (self .c_num_classes )
551
+ self ._ln_a_tilde_mat = np .empty ([self .c_num_classes ,self .c_num_classes ])
552
+ self ._a_tilde_mat = np .empty ([self .c_num_classes ,self .c_num_classes ])
553
+ self ._ln_c_hn_zeta_vecs_sum = 0.0
554
+
555
+ # statistics
556
+ self .x_bar_vecs = np .empty ([self .c_num_classes ,self .c_degree ])
557
+ self .ns = np .empty (self .c_num_classes )
558
+ self .ms = np .empty ([self .c_num_classes ,self .c_num_classes ])
559
+ self .s_mats = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
560
+
561
+ # variational lower bound
562
+ self .vl = 0.0
563
+ self ._vl_p_x = 0.0
564
+ self ._vl_p_z = 0.0
565
+ self ._vl_p_pi = 0.0
566
+ self ._vl_p_a = 0.0
567
+ self ._vl_p_mu_lambda = 0.0
568
+ self ._vl_q_z = 0.0
569
+ self ._vl_q_pi = 0.0
570
+ self ._vl_q_a = 0.0
571
+ self ._vl_q_mu_lambda = 0.0
572
+
531
573
# p_params
574
+ self .p_a_mat = np .ones ([self .c_num_classes ,self .c_num_classes ]) / self .c_num_classes
532
575
self .p_mu_vecs = np .empty ([self .c_num_classes ,self .c_degree ])
533
576
self .p_nus = np .empty ([self .c_num_classes ])
534
577
self .p_lambda_mats = np .empty ([self .c_num_classes ,self .c_degree ,self .c_degree ])
@@ -621,6 +664,7 @@ def set_h0_params(
621
664
622
665
self .h0_w_mats_inv [:] = np .linalg .inv (self .h0_w_mats )
623
666
667
+ self ._calc_prior_char ()
624
668
self .reset_hn_params ()
625
669
626
670
def get_h0_params (self ):
@@ -721,6 +765,10 @@ def set_hn_params(
721
765
722
766
self .hn_w_mats_inv [:] = np .linalg .inv (self .hn_w_mats )
723
767
768
+ self ._calc_q_pi_char ()
769
+ self ._calc_q_a_char ()
770
+ self ._calc_q_lambda_char ()
771
+
724
772
self .calc_pred_dist ()
725
773
726
774
def get_hn_params (self ):
@@ -743,6 +791,117 @@ def get_hn_params(self):
743
791
'hn_nus' :self .hn_nus ,
744
792
'hn_w_mats' :self .hn_w_mats }
745
793
794
+ def _calc_prior_char (self ):
795
+ self ._ln_c_h0_eta_vec = gammaln (self .h0_eta_vec .sum ()) - gammaln (self .h0_eta_vec ).sum ()
796
+ self ._ln_c_h0_zeta_vecs_sum = np .sum (gammaln (self .h0_zeta_vecs .sum (axis = 1 )) - gammaln (self .h0_zeta_vecs ).sum (axis = 1 ))
797
+ self ._ln_b_h0_w_nus = (
798
+ - self .h0_nus * np .linalg .slogdet (self .h0_w_mats )[1 ]
799
+ - self .h0_nus * self .c_degree * np .log (2.0 )
800
+ - self .c_degree * (self .c_degree - 1 )/ 2.0 * np .log (np .pi )
801
+ - np .sum (gammaln ((self .h0_nus [:,np .newaxis ]- np .arange (self .c_degree )) / 2.0 ),
802
+ axis = 1 ) * 2.0
803
+ ) / 2.0
804
+
805
+ def _calc_n_m_x_bar_s (self ,x ):
806
+ self .ns [:] = self .gamma_vecs .sum (axis = 0 )
807
+ self .ms [:] = self .xi_mats .sum (axis = 0 ) # xi must be initialized as a zero matrix
808
+ self .x_bar_vecs [:] = (self .gamma_vecs [:,:,np .newaxis ] * x [:,np .newaxis ,:]).sum (axis = 0 ) / self .ns [:,np .newaxis ]
809
+ self .s_mats [:] = np .sum (self .gamma_vecs [:,:,np .newaxis ,np .newaxis ]
810
+ * ((x [:,np .newaxis ,:] - self .x_bar_vecs )[:,:,:,np .newaxis ]
811
+ @ (x [:,np .newaxis ,:] - self .x_bar_vecs )[:,:,np .newaxis ,:]),
812
+ axis = 0 ) / self .ns [:,np .newaxis ,np .newaxis ]
813
+
814
+ def _calc_q_pi_char (self ):
815
+ self ._ln_pi_tilde_vec [:] = digamma (self .hn_eta_vec ) - digamma (self .hn_eta_vec .sum ())
816
+ self ._pi_tilde_vec [:] = np .exp (self ._ln_pi_tilde_vec )
817
+ # self._pi_tilde_vec[:] = np.exp(self._ln_pi_tilde_vec - self._ln_pi_tilde_vec.max())
818
+ # self._pi_tilde_vec[:] /= self._pi_tilde_vec.sum()
819
+
820
+ def _calc_q_a_char (self ):
821
+ self ._ln_a_tilde_mat [:] = digamma (self .hn_zeta_vecs ) - digamma (self .hn_zeta_vecs .sum (axis = 1 ,keepdims = True ))
822
+ self ._a_tilde_mat [:] = np .exp (self ._ln_a_tilde_mat )
823
+ # self._a_tilde_mat[:] = np.exp(self._ln_a_tilde_mat - self._ln_a_tilde_mat.max(axis=1,keepdims=True))
824
+ # self._a_tilde_mat[:] /= self._a_tilde_mat.sum(axis=1,keepdims=True)
825
+ self ._ln_c_hn_zeta_vecs_sum = np .sum (gammaln (self .hn_zeta_vecs .sum (axis = 1 )) - gammaln (self .hn_zeta_vecs ).sum (axis = 1 ))
826
+
827
+ def _calc_q_lambda_char (self ):
828
+ self ._e_lambda_mats [:] = self .hn_nus [:,np .newaxis ,np .newaxis ] * self .hn_w_mats
829
+ self ._e_ln_lambda_dets [:] = (np .sum (digamma ((self .hn_nus [:,np .newaxis ]- np .arange (self .c_degree )) / 2.0 ),axis = 1 )
830
+ + self .c_degree * np .log (2.0 )
831
+ - np .linalg .slogdet (self .hn_w_mats_inv )[1 ])
832
+ self ._ln_b_hn_w_nus [:] = (
833
+ self .hn_nus * np .linalg .slogdet (self .hn_w_mats_inv )[1 ]
834
+ - self .hn_nus * self .c_degree * np .log (2.0 )
835
+ - self .c_degree * (self .c_degree - 1 )/ 2.0 * np .log (np .pi )
836
+ - np .sum (gammaln ((self .hn_nus [:,np .newaxis ]- np .arange (self .c_degree )) / 2.0 ),
837
+ axis = 1 ) * 2.0
838
+ ) / 2.0
839
+
840
+ def calc_vl (self ):
841
+ # E[ln p(X|Z,mu,Lambda)]
842
+ self ._vl_p_x = np .sum (
843
+ self .ns
844
+ * (self ._e_ln_lambda_dets - self .c_degree / self .hn_kappas
845
+ - (self .s_mats * self ._e_lambda_mats ).sum (axis = (1 ,2 ))
846
+ - ((self .x_bar_vecs - self .hn_m_vecs )[:,np .newaxis ,:]
847
+ @ self ._e_lambda_mats
848
+ @ (self .x_bar_vecs - self .hn_m_vecs )[:,:,np .newaxis ]
849
+ )[:,0 ,0 ]
850
+ - self .c_degree * np .log (2 * np .pi )
851
+ )
852
+ ) / 2.0
853
+
854
+ # E[ln p(Z|pi)]
855
+ self ._vl_p_z = (self .gamma_vecs [0 ] * self ._ln_pi_tilde_vec ).sum () + (self .ms * self ._ln_a_tilde_mat ).sum ()
856
+
857
+ # E[ln p(pi)]
858
+ self ._vl_p_pi = self ._ln_c_h0_eta_vec + ((self .h0_eta_vec - 1 ) * self ._ln_pi_tilde_vec ).sum ()
859
+
860
+ # E[ln p(A)]
861
+ self ._vl_p_a = self ._ln_c_h0_zeta_vecs_sum + ((self .h0_zeta_vecs - 1 ) * self ._ln_a_tilde_mat ).sum ()
862
+
863
+ # E[ln p(mu,Lambda)]
864
+ self ._vl_p_mu_lambda = np .sum (
865
+ self .c_degree * (np .log (self .h0_kappas ) - np .log (2 * np .pi )
866
+ - self .h0_kappas / self .hn_kappas )
867
+ - self .h0_kappas * ((self .hn_m_vecs - self .h0_m_vecs )[:,np .newaxis ,:]
868
+ @ self ._e_lambda_mats
869
+ @ (self .hn_m_vecs - self .h0_m_vecs )[:,:,np .newaxis ])[:,0 ,0 ]
870
+ + 2.0 * self ._ln_b_h0_w_nus
871
+ + (self .h0_nus - self .c_degree ) * self ._e_ln_lambda_dets
872
+ - np .sum (self .h0_w_mats_inv * self ._e_lambda_mats ,axis = (1 ,2 ))
873
+ ) / 2.0
874
+
875
+ # E[ln q(Z|pi)]
876
+ self ._vl_q_z = (- (self .gamma_vecs * self ._ln_rho ).sum ()
877
+ - (self .ms * self ._ln_a_tilde_mat ).sum ()
878
+ - (self .gamma_vecs [0 ] * self ._ln_pi_tilde_vec ).sum ()
879
+ + np .log (self ._cs ).sum ())
880
+
881
+ # E[ln q(pi)]
882
+ self ._vl_q_pi = ss_dirichlet .entropy (self .hn_eta_vec )
883
+
884
+ # E[ln p(A)]
885
+ self ._vl_q_a = - self ._ln_c_hn_zeta_vecs_sum - ((self .hn_zeta_vecs - 1 ) * self ._ln_a_tilde_mat ).sum ()
886
+
887
+ # E[ln q(mu,Lambda)]
888
+ self ._vl_q_mu_lambda = np .sum (
889
+ + self .c_degree * (1.0 + np .log (2.0 * np .pi ) - np .log (self .hn_kappas ))
890
+ - self ._ln_b_hn_w_nus * 2.0
891
+ - (self .hn_nus - self .c_degree )* self ._e_ln_lambda_dets
892
+ + self .hn_nus * self .c_degree
893
+ ) / 2.0
894
+
895
+ self .vl = (self ._vl_p_x
896
+ + self ._vl_p_z
897
+ + self ._vl_p_pi
898
+ + self ._vl_p_a
899
+ + self ._vl_p_mu_lambda
900
+ + self ._vl_q_z
901
+ + self ._vl_q_pi
902
+ + self ._vl_q_a
903
+ + self ._vl_q_mu_lambda )
904
+
746
905
def update_posterior ():
747
906
"""Update the the posterior distribution using traning data.
748
907
"""
0 commit comments