Skip to content

Commit 151f00a

Browse files
committed
Add calc_vl
1 parent bc51400 commit 151f00a

File tree

2 files changed

+167
-7
lines changed

2 files changed

+167
-7
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,11 @@ def visualize_model(self,sample_length=200):
423423
alpha=0.3,
424424
ls='',
425425
)
426+
axes.plot(
427+
np.linspace(change_points[i-1],change_points[i],100),
428+
np.ones(100) * self.mu_vecs[np.argmax(latent_vars[change_points[i-1]])],
429+
c='red',
430+
)
426431
axes.plot(np.arange(sample.shape[0]),sample)
427432
axes.set_xlabel("time")
428433
axes.set_ylabel("x")
@@ -519,6 +524,10 @@ def __init__(
519524
self.h0_w_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,1,1])
520525
self.h0_w_mats_inv = np.linalg.inv(self.h0_w_mats)
521526

527+
self._ln_c_h0_eta_vec = 0.0
528+
self._ln_c_h0_zeta_vecs_sum = 0.0
529+
self._ln_b_h0_w_nus = np.empty(self.c_num_classes)
530+
522531
# hn_params
523532
self.hn_eta_vec = np.empty(self.c_num_classes)
524533
self.hn_zeta_vecs = np.empty([self.c_num_classes,self.c_num_classes])
@@ -528,7 +537,41 @@ def __init__(
528537
self.hn_w_mats = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
529538
self.hn_w_mats_inv = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
530539

540+
self._ln_rho = None
541+
self.alpha_vecs = None
542+
self.beta_vecs = None
543+
self.gamma_vecs = None
544+
self.xi_mats = None
545+
self._cs = None
546+
self._e_lambda_mats = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
547+
self._e_ln_lambda_dets = np.empty(self.c_num_classes)
548+
self._ln_b_hn_w_nus = np.empty(self.c_num_classes)
549+
self._ln_pi_tilde_vec = np.empty(self.c_num_classes)
550+
self._pi_tilde_vec = np.empty(self.c_num_classes)
551+
self._ln_a_tilde_mat = np.empty([self.c_num_classes,self.c_num_classes])
552+
self._a_tilde_mat = np.empty([self.c_num_classes,self.c_num_classes])
553+
self._ln_c_hn_zeta_vecs_sum = 0.0
554+
555+
# statistics
556+
self.x_bar_vecs = np.empty([self.c_num_classes,self.c_degree])
557+
self.ns = np.empty(self.c_num_classes)
558+
self.ms = np.empty([self.c_num_classes,self.c_num_classes])
559+
self.s_mats = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
560+
561+
# variational lower bound
562+
self.vl = 0.0
563+
self._vl_p_x = 0.0
564+
self._vl_p_z = 0.0
565+
self._vl_p_pi = 0.0
566+
self._vl_p_a = 0.0
567+
self._vl_p_mu_lambda = 0.0
568+
self._vl_q_z = 0.0
569+
self._vl_q_pi = 0.0
570+
self._vl_q_a = 0.0
571+
self._vl_q_mu_lambda = 0.0
572+
531573
# p_params
574+
self.p_a_mat = np.ones([self.c_num_classes,self.c_num_classes]) / self.c_num_classes
532575
self.p_mu_vecs = np.empty([self.c_num_classes,self.c_degree])
533576
self.p_nus = np.empty([self.c_num_classes])
534577
self.p_lambda_mats = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
@@ -621,6 +664,7 @@ def set_h0_params(
621664

622665
self.h0_w_mats_inv[:] = np.linalg.inv(self.h0_w_mats)
623666

667+
self._calc_prior_char()
624668
self.reset_hn_params()
625669

626670
def get_h0_params(self):
@@ -721,6 +765,10 @@ def set_hn_params(
721765

722766
self.hn_w_mats_inv[:] = np.linalg.inv(self.hn_w_mats)
723767

768+
self._calc_q_pi_char()
769+
self._calc_q_a_char()
770+
self._calc_q_lambda_char()
771+
724772
self.calc_pred_dist()
725773

726774
def get_hn_params(self):
@@ -743,6 +791,117 @@ def get_hn_params(self):
743791
'hn_nus':self.hn_nus,
744792
'hn_w_mats':self.hn_w_mats}
745793

794+
def _calc_prior_char(self):
795+
self._ln_c_h0_eta_vec = gammaln(self.h0_eta_vec.sum()) - gammaln(self.h0_eta_vec).sum()
796+
self._ln_c_h0_zeta_vecs_sum = np.sum(gammaln(self.h0_zeta_vecs.sum(axis=1)) - gammaln(self.h0_zeta_vecs).sum(axis=1))
797+
self._ln_b_h0_w_nus = (
798+
- self.h0_nus*np.linalg.slogdet(self.h0_w_mats)[1]
799+
- self.h0_nus*self.c_degree*np.log(2.0)
800+
- self.c_degree*(self.c_degree-1)/2.0*np.log(np.pi)
801+
- np.sum(gammaln((self.h0_nus[:,np.newaxis]-np.arange(self.c_degree)) / 2.0),
802+
axis=1) * 2.0
803+
) / 2.0
804+
805+
def _calc_n_m_x_bar_s(self,x):
806+
self.ns[:] = self.gamma_vecs.sum(axis=0)
807+
self.ms[:] = self.xi_mats.sum(axis=0) # xi must be initialized as a zero matrix
808+
self.x_bar_vecs[:] = (self.gamma_vecs[:,:,np.newaxis] * x[:,np.newaxis,:]).sum(axis=0) / self.ns[:,np.newaxis]
809+
self.s_mats[:] = np.sum(self.gamma_vecs[:,:,np.newaxis,np.newaxis]
810+
* ((x[:,np.newaxis,:] - self.x_bar_vecs)[:,:,:,np.newaxis]
811+
@ (x[:,np.newaxis,:] - self.x_bar_vecs)[:,:,np.newaxis,:]),
812+
axis=0) / self.ns[:,np.newaxis,np.newaxis]
813+
814+
def _calc_q_pi_char(self):
815+
self._ln_pi_tilde_vec[:] = digamma(self.hn_eta_vec) - digamma(self.hn_eta_vec.sum())
816+
self._pi_tilde_vec[:] = np.exp(self._ln_pi_tilde_vec)
817+
# self._pi_tilde_vec[:] = np.exp(self._ln_pi_tilde_vec - self._ln_pi_tilde_vec.max())
818+
# self._pi_tilde_vec[:] /= self._pi_tilde_vec.sum()
819+
820+
def _calc_q_a_char(self):
821+
self._ln_a_tilde_mat[:] = digamma(self.hn_zeta_vecs) - digamma(self.hn_zeta_vecs.sum(axis=1,keepdims=True))
822+
self._a_tilde_mat[:] = np.exp(self._ln_a_tilde_mat)
823+
# self._a_tilde_mat[:] = np.exp(self._ln_a_tilde_mat - self._ln_a_tilde_mat.max(axis=1,keepdims=True))
824+
# self._a_tilde_mat[:] /= self._a_tilde_mat.sum(axis=1,keepdims=True)
825+
self._ln_c_hn_zeta_vecs_sum = np.sum(gammaln(self.hn_zeta_vecs.sum(axis=1)) - gammaln(self.hn_zeta_vecs).sum(axis=1))
826+
827+
def _calc_q_lambda_char(self):
828+
self._e_lambda_mats[:] = self.hn_nus[:,np.newaxis,np.newaxis] * self.hn_w_mats
829+
self._e_ln_lambda_dets[:] = (np.sum(digamma((self.hn_nus[:,np.newaxis]-np.arange(self.c_degree)) / 2.0),axis=1)
830+
+ self.c_degree*np.log(2.0)
831+
- np.linalg.slogdet(self.hn_w_mats_inv)[1])
832+
self._ln_b_hn_w_nus[:] = (
833+
self.hn_nus*np.linalg.slogdet(self.hn_w_mats_inv)[1]
834+
- self.hn_nus*self.c_degree*np.log(2.0)
835+
- self.c_degree*(self.c_degree-1)/2.0*np.log(np.pi)
836+
- np.sum(gammaln((self.hn_nus[:,np.newaxis]-np.arange(self.c_degree)) / 2.0),
837+
axis=1) * 2.0
838+
) / 2.0
839+
840+
def calc_vl(self):
841+
# E[ln p(X|Z,mu,Lambda)]
842+
self._vl_p_x = np.sum(
843+
self.ns
844+
* (self._e_ln_lambda_dets - self.c_degree / self.hn_kappas
845+
- (self.s_mats * self._e_lambda_mats).sum(axis=(1,2))
846+
- ((self.x_bar_vecs - self.hn_m_vecs)[:,np.newaxis,:]
847+
@ self._e_lambda_mats
848+
@ (self.x_bar_vecs - self.hn_m_vecs)[:,:,np.newaxis]
849+
)[:,0,0]
850+
- self.c_degree * np.log(2*np.pi)
851+
)
852+
) / 2.0
853+
854+
# E[ln p(Z|pi)]
855+
self._vl_p_z = (self.gamma_vecs[0] * self._ln_pi_tilde_vec).sum() + (self.ms * self._ln_a_tilde_mat).sum()
856+
857+
# E[ln p(pi)]
858+
self._vl_p_pi = self._ln_c_h0_eta_vec + ((self.h0_eta_vec - 1) * self._ln_pi_tilde_vec).sum()
859+
860+
# E[ln p(A)]
861+
self._vl_p_a = self._ln_c_h0_zeta_vecs_sum + ((self.h0_zeta_vecs - 1) * self._ln_a_tilde_mat).sum()
862+
863+
# E[ln p(mu,Lambda)]
864+
self._vl_p_mu_lambda = np.sum(
865+
self.c_degree * (np.log(self.h0_kappas) - np.log(2*np.pi)
866+
- self.h0_kappas/self.hn_kappas)
867+
- self.h0_kappas * ((self.hn_m_vecs - self.h0_m_vecs)[:,np.newaxis,:]
868+
@ self._e_lambda_mats
869+
@ (self.hn_m_vecs - self.h0_m_vecs)[:,:,np.newaxis])[:,0,0]
870+
+ 2.0 * self._ln_b_h0_w_nus
871+
+ (self.h0_nus - self.c_degree) * self._e_ln_lambda_dets
872+
- np.sum(self.h0_w_mats_inv * self._e_lambda_mats,axis=(1,2))
873+
) / 2.0
874+
875+
# E[ln q(Z|pi)]
876+
self._vl_q_z = (-(self.gamma_vecs * self._ln_rho).sum()
877+
-(self.ms * self._ln_a_tilde_mat).sum()
878+
-(self.gamma_vecs[0] * self._ln_pi_tilde_vec).sum()
879+
+np.log(self._cs).sum())
880+
881+
# E[ln q(pi)]
882+
self._vl_q_pi = ss_dirichlet.entropy(self.hn_eta_vec)
883+
884+
# E[ln p(A)]
885+
self._vl_q_a = -self._ln_c_hn_zeta_vecs_sum - ((self.hn_zeta_vecs - 1) * self._ln_a_tilde_mat).sum()
886+
887+
# E[ln q(mu,Lambda)]
888+
self._vl_q_mu_lambda = np.sum(
889+
+ self.c_degree * (1.0 + np.log(2.0*np.pi) - np.log(self.hn_kappas))
890+
- self._ln_b_hn_w_nus * 2.0
891+
- (self.hn_nus-self.c_degree)*self._e_ln_lambda_dets
892+
+ self.hn_nus * self.c_degree
893+
) / 2.0
894+
895+
self.vl = (self._vl_p_x
896+
+ self._vl_p_z
897+
+ self._vl_p_pi
898+
+ self._vl_p_a
899+
+ self._vl_p_mu_lambda
900+
+ self._vl_q_z
901+
+ self._vl_q_pi
902+
+ self._vl_q_a
903+
+ self._vl_q_mu_lambda)
904+
746905
def update_posterior():
747906
"""Update the the posterior distribution using traning data.
748907
"""
Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from bayesml import hiddenmarkovnormal
22
import numpy as np
33

4-
model = hiddenmarkovnormal.GenModel(3,1)
5-
6-
print(model.get_params())
7-
8-
model.set_params(mu_vecs=np.ones([3,1]))
9-
10-
print(model.get_params())
4+
model = hiddenmarkovnormal.LearnModel(
5+
c_num_classes=3,
6+
c_degree=1)
7+
# model.visualize_model()
8+
9+
print(model._ln_c_h0_eta_vec)
10+
print(model._ln_c_h0_zeta_vecs)
11+
print(model._ln_b_h0_w_nus)

0 commit comments

Comments
 (0)