Skip to content

Commit 07ec7e4

Browse files
committed
Add VB alg
1 parent 151f00a commit 07ec7e4

File tree

3 files changed

+211
-10
lines changed

3 files changed

+211
-10
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 196 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,9 @@ def __init__(
537537
self.hn_w_mats = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
538538
self.hn_w_mats_inv = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
539539

540+
self._length = 0
540541
self._ln_rho = None
542+
self._rho = None
541543
self.alpha_vecs = None
542544
self.beta_vecs = None
543545
self.gamma_vecs = None
@@ -901,11 +903,203 @@ def calc_vl(self):
901903
+ self._vl_q_pi
902904
+ self._vl_q_a
903905
+ self._vl_q_mu_lambda)
906+
907+
def _init_fb_params(self):
908+
self._ln_rho[:] = 0.0
909+
self._rho[:] = 1.0
910+
self.alpha_vecs[:] = 1/self.c_num_classes
911+
self.beta_vecs[:] = 1.0
912+
self.gamma_vecs[:] = 1/self.c_num_classes
913+
self.xi_mats[:] = 1/(self.c_num_classes**2)
914+
self.xi_mats[0] = 0.0
915+
self._cs[:] = 1.0
916+
917+
def _init_random_responsibility(self,x):
    """Initialize the responsibilities (xi_mats, gamma_vecs) at random.

    Each xi_mats[i] is drawn from a symmetric Dirichlet over all
    c_num_classes**2 entries, so every (K, K) slice sums to one.

    Parameters
    ----------
    x : numpy.ndarray
        Training data; presumably shape (length, c_degree) — reshaped by the
        caller before this is invoked.
    """
    # One Dirichlet draw per time step, reshaped to (length, K, K).
    self.xi_mats[:] = self.rng.dirichlet(np.ones(self.c_num_classes**2),self.xi_mats.shape[0]).reshape(self.xi_mats.shape)
    # The first step has no preceding state, so it carries no transition mass.
    self.xi_mats[0] = 0.0
    # Marginalize out the previous state (axis 1) to get per-step responsibilities.
    self.gamma_vecs[:] = self.xi_mats.sum(axis=1)
    # gamma_vecs[0] is recovered from xi_mats[1] by summing over the *next* state instead.
    self.gamma_vecs[0] = self.xi_mats[1].sum(axis=1)
    # Refresh the sufficient statistics (presumably ns, ms, x_bar_vecs, s_mats
    # used by the q(mu,Lambda)/q(pi)/q(A) updates — confirm against the helper).
    self._calc_n_m_x_bar_s(x)
924+
def _init_subsampling(self,x):
    """Initialize hn_m_vecs and hn_w_mats(_inv) from random subsamples of x.

    For each latent class, a subsample of size int(np.sqrt(length)) is drawn
    without replacement and its mean / scaled scatter matrix are used as the
    initial hyperparameter values.

    Parameters
    ----------
    x : numpy.ndarray
        Training data with shape (length, c_degree).
    """
    _size = int(np.sqrt(self._length))
    for k in range(self.c_num_classes):
        _subsample = self.rng.choice(x,size=_size,replace=False,axis=0,shuffle=False)
        self.hn_m_vecs[k] = _subsample.sum(axis=0) / _size
        # Subsample scatter scaled by hn_nus[k] — presumably so the expected
        # precision nu*W of the Wishart q(Lambda_k) matches the subsample's
        # inverse covariance (confirm against the model's q(Lambda) definition).
        self.hn_w_mats_inv[k] = ((_subsample - self.hn_m_vecs[k]).T
                                 @ (_subsample - self.hn_m_vecs[k])
                                 / _size * self.hn_nus[k]
                                 + np.identity(self.c_degree) * 1.0E-5) # avoid singular matrix
        self.hn_w_mats[k] = np.linalg.inv(self.hn_w_mats_inv[k])
    # Refresh cached characteristics of q(Lambda) from the new hyperparameters.
    self._calc_q_lambda_char()
936+
def _update_q_mu_lambda(self):
    """Update the Gauss-Wishart hyperparameters of q(mu_k, Lambda_k).

    Standard conjugate update from the prior hyperparameters (h0_*) and the
    expected sufficient statistics (ns, x_bar_vecs, s_mats), applied to all
    classes at once via broadcasting over the leading class axis.
    """
    self.hn_kappas[:] = self.h0_kappas + self.ns
    # Precision-weighted mean of the prior mean and the class sample mean.
    self.hn_m_vecs[:] = (self.h0_kappas[:,np.newaxis] * self.h0_m_vecs
                         + self.ns[:,np.newaxis] * self.x_bar_vecs) / self.hn_kappas[:,np.newaxis]
    self.hn_nus[:] = self.h0_nus + self.ns
    # Inverse scale matrix: prior term + within-class scatter
    # + between-means correction (outer product of x_bar - m0 per class).
    self.hn_w_mats_inv[:] = (self.h0_w_mats_inv
                             + self.ns[:,np.newaxis,np.newaxis] * self.s_mats
                             + (self.h0_kappas * self.ns / self.hn_kappas)[:,np.newaxis,np.newaxis]
                             * ((self.x_bar_vecs - self.h0_m_vecs)[:,:,np.newaxis]
                                @ (self.x_bar_vecs - self.h0_m_vecs)[:,np.newaxis,:])
                             )
    self.hn_w_mats[:] = np.linalg.inv(self.hn_w_mats_inv)
    # Refresh cached characteristics of q(Lambda) (expectations used by _calc_rho).
    self._calc_q_lambda_char()
904949

905-
def update_posterior():
950+
def _update_q_pi(self):
    """Update the Dirichlet hyperparameters of q(pi) and refresh its cached
    characteristics.

    Bug fix: in an HMM the posterior of the initial-state distribution pi
    depends only on the expected assignment of the *first* latent state,
    gamma_vecs[0]. The previous code added ns (expected counts summed over
    all time steps), which is the i.i.d. mixture-model update and
    double-counts every later time step.
    """
    self.hn_eta_vec[:] = self.h0_eta_vec + self.gamma_vecs[0]
    self._calc_q_pi_char()
953+
954+
def _update_q_a(self):
    """Update the Dirichlet hyperparameters of q(A) from the expected
    transition counts ms, then refresh its cached characteristics.
    """
    updated = self.h0_zeta_vecs + self.ms
    self.hn_zeta_vecs[:] = updated
    self._calc_q_a_char()
957+
958+
def _calc_rho(self,x):
    """Compute the per-step, per-class log emission weights _ln_rho and _rho.

    For each time step i and class k:
    ln rho[i,k] = ( E[ln det Lambda_k] - D ln(2 pi) - D / kappa_k
                    - (x_i - m_k)^T E[Lambda_k] (x_i - m_k) ) / 2
    where D = c_degree and the expectations come from the cached q(Lambda)
    characteristics (_e_ln_lambda_dets, _e_lambda_mats).

    Parameters
    ----------
    x : numpy.ndarray
        Training data with shape (length, c_degree).
    """
    # The quadratic form is built by broadcasting to shape
    # (length, K, 1, degree) @ (K, degree, degree) @ (length, K, degree, 1),
    # then squeezing the trailing singleton axes with [:,:,0,0].
    self._ln_rho[:] = ((self._e_ln_lambda_dets
                        - self.c_degree * np.log(2*np.pi)
                        - self.c_degree / self.hn_kappas
                        - ((x[:,np.newaxis,:]-self.hn_m_vecs)[:,:,np.newaxis,:]
                           @ self._e_lambda_mats
                           @ (x[:,np.newaxis,:]-self.hn_m_vecs)[:,:,:,np.newaxis]
                           )[:,:,0,0]
                        ) / 2.0
                       )
    self._rho[:] = np.exp(self._ln_rho)
969+
970+
def _forward(self):
971+
self.alpha_vecs[0] = self._rho[0] * self._pi_tilde_vec
972+
self._cs[0] = self.alpha_vecs[0].sum()
973+
self.alpha_vecs[0] /= self._cs[0]
974+
for i in range(1,self._length):
975+
self.alpha_vecs[i] = self._rho[i] * (self.alpha_vecs[i-1] @ self._a_tilde_mat)
976+
self._cs[i] = self.alpha_vecs[i].sum()
977+
self.alpha_vecs[i] /= self._cs[i]
978+
979+
def _backward(self):
980+
for i in range(self._length-2,-1,-1):
981+
self.beta_vecs[i] = self._a_tilde_mat @ (self._rho[i+1] * self.beta_vecs[i+1])
982+
self.beta_vecs[i] /= self._cs[i+1]
983+
984+
def _update_gamma(self):
985+
self.gamma_vecs[:] = self.alpha_vecs * self.beta_vecs
986+
987+
def _update_xi(self):
988+
self.xi_mats[1:,:,:] = self.alpha_vecs[:-1,:,np.newaxis] * self._rho[1:,np.newaxis,:] * self._a_tilde_mat[np.newaxis,:,:] * self.beta_vecs[1:,np.newaxis,:]
989+
self.xi_mats[1:,:,:] /= self._cs[1:,np.newaxis,np.newaxis]
990+
991+
def _update_q_z(self,x):
    """Update q(z) by one scaled forward-backward pass, then refresh the
    expected sufficient statistics used by the other q(*) updates.

    Parameters
    ----------
    x : numpy.ndarray
        Training data with shape (length, c_degree).
    """
    self._calc_rho(x)          # per-step, per-class emission weights
    self._forward()            # alpha recursion (also fills scaling constants _cs)
    self._backward()           # beta recursion (reuses _cs)
    self._update_gamma()       # marginal responsibilities gamma_vecs
    self._update_xi()          # pairwise responsibilities xi_mats
    self._calc_n_m_x_bar_s(x)  # sufficient statistics from the new responsibilities
998+
999+
def update_posterior(
        self,
        x,
        max_itr=100,
        num_init=10,
        tolerance=1.0E-8,
        init_type='subsampling'
        ):
    """Update the posterior distribution using training data.

    Runs variational Bayes with num_init random restarts and keeps the
    hyperparameters of the restart that attains the highest variational
    lower bound.

    Parameters
    ----------
    x : numpy.ndarray
        All the elements must be real number.
    max_itr : int, optional
        maximum number of iterations, by default 100
    num_init : int, optional
        number of initializations, by default 10
    tolerance : float, optional
        convergence criterion of variational lower bound, by default 1.0E-8
    init_type : str, optional
        type of initialization, by default 'subsampling'

        * 'subsampling': for each latent class, extract a subsample whose size
          is int(np.sqrt(x.shape[0])) and use its mean and covariance matrix
          as initial values of hn_m_vecs and hn_lambda_mats.
        * 'random_responsibility': randomly assign responsibility to gamma_vecs

    Raises
    ------
    DataFormatError
        If x's last axis does not match self.c_degree.
    ValueError
        If init_type is not one of the supported strings.
    """
    _check.float_vecs(x,'x',DataFormatError)
    if x.shape[-1] != self.c_degree:
        raise(DataFormatError(
            "x.shape[-1] must be self.c_degree: "
            + f"x.shape[-1]={x.shape[-1]}, self.c_degree={self.c_degree}"))
    x = x.reshape(-1,self.c_degree)

    # Allocate the forward-backward working arrays for this sequence length.
    self._length = x.shape[0]
    self._ln_rho = np.zeros([self._length,self.c_num_classes])
    self._rho = np.ones([self._length,self.c_num_classes])
    self.alpha_vecs = np.ones([self._length,self.c_num_classes])/self.c_num_classes
    self.beta_vecs = np.ones([self._length,self.c_num_classes])
    # NOTE(review): zeros/K**2 is still all zeros — np.ones was probably
    # intended. Harmless here because _init_fb_params overwrites xi_mats
    # before first use, but worth confirming.
    self.xi_mats = np.zeros([self._length,self.c_num_classes,self.c_num_classes])/(self.c_num_classes**2)
    self.gamma_vecs = np.ones([self._length,self.c_num_classes])/self.c_num_classes
    self._cs = np.ones([self._length])

    # Buffers holding the best hyperparameters seen across restarts.
    tmp_vl = 0.0
    tmp_eta_vec = np.copy(self.hn_eta_vec)
    tmp_zeta_vecs = np.copy(self.hn_zeta_vecs)
    tmp_m_vecs = np.copy(self.hn_m_vecs)
    tmp_kappas = np.copy(self.hn_kappas)
    tmp_nus = np.copy(self.hn_nus)
    tmp_w_mats = np.copy(self.hn_w_mats)
    tmp_w_mats_inv = np.copy(self.hn_w_mats_inv)

    convergence_flag = True  # stays True only if no restart ever converges
    for i in range(num_init):
        self._init_fb_params()
        self.reset_hn_params()
        if init_type == 'subsampling':
            self._init_subsampling(x)
            self._update_q_z(x)
        elif init_type == 'random_responsibility':
            self._init_random_responsibility(x)
        else:
            raise(ValueError(
                f'init_type={init_type} is unsupported. '
                + 'This function supports only '
                + '"subsampling" and "random_responsibility"'))
        self.calc_vl()
        print(f'\r{i}. VL: {self.vl}',end='')
        # Coordinate-ascent VB iterations until the lower bound stalls.
        for t in range(max_itr):
            vl_before = self.vl
            self._update_q_mu_lambda()
            self._update_q_pi()
            self._update_q_a()
            self._update_q_z(x)
            self.calc_vl()
            print(f'\r{i}. VL: {self.vl} t={t} ',end='')
            # Relative change of the variational lower bound as stopping rule.
            if np.abs((self.vl-vl_before)/vl_before) < tolerance:
                convergence_flag = False
                print(f'(converged)',end='')
                break
        # Keep the restart with the highest lower bound ('*' marks a new best).
        if i==0 or self.vl > tmp_vl:
            print('*')
            tmp_vl = self.vl
            tmp_eta_vec[:] = self.hn_eta_vec
            tmp_zeta_vecs[:] = self.hn_zeta_vecs
            tmp_m_vecs[:] = self.hn_m_vecs
            tmp_kappas[:] = self.hn_kappas
            tmp_nus[:] = self.hn_nus
            tmp_w_mats[:] = self.hn_w_mats
            tmp_w_mats_inv[:] = self.hn_w_mats_inv
        else:
            print('')
    if convergence_flag:
        warnings.warn("Algorithm has not converged even once.",ResultWarning)

    # Restore the best restart and refresh all cached characteristics so the
    # model's state is consistent with the returned hyperparameters.
    self.hn_eta_vec[:] = tmp_eta_vec
    self.hn_zeta_vecs[:] = tmp_zeta_vecs
    self.hn_m_vecs[:] = tmp_m_vecs
    self.hn_kappas[:] = tmp_kappas
    self.hn_nus[:] = tmp_nus
    self.hn_w_mats[:] = tmp_w_mats
    self.hn_w_mats_inv[:] = tmp_w_mats_inv
    self._calc_q_pi_char()
    self._calc_q_a_char()
    self._calc_q_lambda_char()
    self._update_q_z(x)
9091103

9101104
def estimate_params(self,loss="squared"):
9111105
"""Estimate the parameter under the given criterion.
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
# Demo: generate a sample sequence from GenModel and fit LearnModel to it.
from bayesml import hiddenmarkovnormal
import numpy as np

# Three well-separated 1-D emission means with a "sticky" transition matrix.
gen_model = hiddenmarkovnormal.GenModel(
    c_num_classes=3,
    c_degree=1,
    mu_vecs=np.array([[5],[0],[-5]]),
    a_mat=np.array([[0.95,0.05,0.0],[0.0,0.9,0.1],[0.1,0.0,0.9]])
    )
# gen_model.visualize_model()
x,z = gen_model.gen_sample(sample_length=200)

# Fit a LearnModel with the same structure via variational Bayes.
learn_model = hiddenmarkovnormal.LearnModel(
    c_num_classes=3,
    c_degree=1,
    )
learn_model.update_posterior(x)#,init_type='random_responsibility')
# print(learn_model.get_hn_params())

bayesml/hiddenmarkovnormal/hiddenmarkovnormal.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ $$
112112
\end{cases} \\
113113
\beta^{(t+1)} (\boldsymbol{z}_i) &\propto
114114
\begin{cases}
115-
\sum_{\boldsymbol{z}_{i+1}} \left[ \prod_{k=1}^{K} \left( \rho_{i+1,k}^{(t+1)}\right)^{z_{i,k}} \prod_{k=1}^{K}\prod_{j=1}^{K}\left(\tilde{a}^{(t+1)}_{j,k}\right)^{z_{i,j}z_{i+1,k}}\beta^{(t+1)}(\boldsymbol{z}_{i+1})\right] & (i<n)\\
115+
\sum_{\boldsymbol{z}_{i+1}} \left[ \prod_{k=1}^{K} \left( \rho_{i+1,k}^{(t+1)}\right)^{z_{i+1,k}} \prod_{k=1}^{K}\prod_{j=1}^{K}\left(\tilde{a}^{(t+1)}_{j,k}\right)^{z_{i,j}z_{i+1,k}}\beta^{(t+1)}(\boldsymbol{z}_{i+1})\right] & (i<n)\\
116116
1 & (i=n)
117117
\end{cases} \\
118118
q^{(t+1)}(\boldsymbol{z}_i) &\propto \alpha^{(t+1)}(\boldsymbol{z}_i)\beta^{(t+1)}(\boldsymbol{z}_i) \\

0 commit comments

Comments
 (0)