Skip to content

Commit b564e50

Browse files
committed
updating rule of VB
1 parent a081ced commit b564e50

File tree

4 files changed

+220
-62
lines changed

4 files changed

+220
-62
lines changed

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 182 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from scipy.stats import wishart as ss_wishart
99
from scipy.stats import multivariate_t as ss_multivariate_t
1010
from scipy.stats import dirichlet as ss_dirichlet
11-
from scipy.special import gammaln, digamma, xlogy
11+
from scipy.special import gammaln, digamma, xlogy, logsumexp
1212
import matplotlib.pyplot as plt
1313

1414
from .. import base
@@ -499,15 +499,28 @@ def __init__(
499499
self.hn_w_mats = np.empty([self.num_classes,self.degree,self.degree])
500500
self.hn_w_mats_inv = np.empty([self.num_classes,self.degree,self.degree])
501501

502-
# statistics
502+
self.ln_rho = None
503503
self.r_vecs = None
504-
self.x_bar_vecs = np.empty([self.num_classes,self.degree])
505-
self.ns = np.empty(self.num_classes)
506-
self.s_mats = np.empty([self.num_classes,self.degree,self.degree])
507504
self.e_lambda_mats = np.empty([self.num_classes,self.degree,self.degree])
508505
self.e_ln_lambda_dets = np.empty(self.num_classes)
506+
self.ln_b_hn_w_nus = np.empty(self.num_classes)
509507
self.e_ln_pi_vec = np.empty(self.num_classes)
510508

509+
# statistics
510+
self.x_bar_vecs = np.empty([self.num_classes,self.degree])
511+
self.ns = np.empty(self.num_classes)
512+
self.s_mats = np.empty([self.num_classes,self.degree,self.degree])
513+
514+
# variational lower bound
515+
self.vl = 0.0
516+
self.vl_p_x = 0.0
517+
self.vl_p_z = 0.0
518+
self.vl_p_pi = 0.0
519+
self.vl_p_mu_lambda = 0.0
520+
self.vl_q_z = 0.0
521+
self.vl_q_pi = 0.0
522+
self.vl_q_mu_lambda = 0.0
523+
511524
# p_params
512525
self.p_pi_vec = np.empty([self.num_classes])
513526
self.p_mu_vecs = np.empty([self.num_classes,self.degree])
@@ -695,6 +708,21 @@ def reset_hn_params(self):
695708
self.hn_w_mats[:] = self.h0_w_mats
696709
self.hn_w_mats_inv = np.linalg.inv(self.hn_w_mats)
697710

711+
self.e_lambda_mats[:] = self.hn_nus[:,np.newaxis,np.newaxis] * self.hn_w_mats
712+
self.e_ln_lambda_dets[:] = (
713+
np.sum(digamma((self.hn_nus[:,np.newaxis]-np.arange(self.degree)) / 2.0),axis=1)
714+
+ self.degree*np.log(2.0)
715+
- np.linalg.slogdet(self.hn_w_mats_inv)[1]
716+
)
717+
self.e_ln_pi_vec[:] = digamma(self.hn_alpha_vec) - digamma(self.hn_alpha_vec.sum())
718+
self.ln_b_hn_w_nus[:] = (
719+
self.hn_nus*np.linalg.slogdet(self.hn_w_mats_inv)[1]
720+
- self.hn_nus*self.degree*np.log(2.0)
721+
- self.degree*(self.degree-1)/2.0*np.log(np.pi)
722+
- np.sum(gammaln((self.hn_nus[:,np.newaxis]-np.arange(self.degree)) / 2.0),
723+
axis=1) * 2.0
724+
) / 2.0
725+
698726
self.calc_pred_dist()
699727

700728
def overwrite_h0_params(self):
@@ -713,22 +741,8 @@ def overwrite_h0_params(self):
713741
self.calc_pred_dist()
714742

715743
def calc_vl(self):
716-
self.e_lambda_mats = self.hn_nus[:,np.newaxis,np.newaxis] * self.hn_w_mats
717-
self.e_ln_lambda_dets = (np.sum(digamma((self.hn_nus[:,np.newaxis]-np.arange(self.degree)) / 2.0),axis=1)
718-
+ self.degree*np.log(2.0)
719-
- np.linalg.slogdet(self.hn_w_mats_inv)[1])
720-
self.e_ln_pi_vec = digamma(self.hn_alpha_vec) - digamma(self.hn_alpha_vec.sum())
721-
722-
# tentative
723-
self.ns = np.ones(self.num_classes) * 10
724-
self.s_mats = np.tile(np.identity(self.degree),[self.num_classes,1,1]) * 5
725-
self.r_vecs = np.ones([20,self.degree])/self.degree
726-
self.x_bar_vecs = np.ones([self.num_classes,self.degree])
727-
728-
vl = 0.0
729-
730744
# E[ln p(X|Z,mu,Lambda)]
731-
vl += np.sum(
745+
self.vl_p_x = np.sum(
732746
self.ns
733747
* (self.e_ln_lambda_dets - self.degree / self.hn_kappas
734748
- (self.s_mats * self.e_lambda_mats).sum(axis=(1,2))
@@ -741,63 +755,173 @@ def calc_vl(self):
741755
) / 2.0
742756

743757
# E[ln p(Z|pi)]
744-
vl += (self.ns * self.e_ln_pi_vec).sum()
758+
self.vl_p_z = (self.ns * self.e_ln_pi_vec).sum()
745759

746760
# E[ln p(pi)]
747-
vl += self.LN_C_H0_ALPHA + ((self.h0_alpha_vec - 1) * self.e_ln_pi_vec).sum()
761+
self.vl_p_pi = self.LN_C_H0_ALPHA + ((self.h0_alpha_vec - 1) * self.e_ln_pi_vec).sum()
748762

749763
# E[ln p(mu,Lambda)]
750-
vl += np.sum(
751-
self.degree * (np.log(self.h0_kappas) - np.log(2*np.pi) - self.h0_kappas/self.hn_kappas)
752-
- ((self.hn_m_vecs - self.h0_m_vecs)[:,np.newaxis,:]
753-
@ self.e_lambda_mats
754-
@ (self.hn_m_vecs - self.h0_m_vecs)[:,:,np.newaxis])[:,0,0]
764+
self.vl_p_mu_lambda = np.sum(
765+
self.degree * (np.log(self.h0_kappas) - np.log(2*np.pi)
766+
- self.h0_kappas/self.hn_kappas)
767+
- self.h0_kappas * ((self.hn_m_vecs - self.h0_m_vecs)[:,np.newaxis,:]
768+
@ self.e_lambda_mats
769+
@ (self.hn_m_vecs - self.h0_m_vecs)[:,:,np.newaxis])[:,0,0]
755770
+ 2.0 * self.LN_B_H0_W_NUS
756-
+ (self.h0_nus - self.degree) / 2.0 * self.e_ln_lambda_dets
757-
- np.sum(self.h0_w_mats_inv * self.hn_w_mats,axis=(1,2))
771+
+ (self.h0_nus - self.degree) * self.e_ln_lambda_dets
772+
- np.sum(self.h0_w_mats_inv * self.e_lambda_mats,axis=(1,2))
758773
) / 2.0
759774

760775
# E[ln q(Z|pi)]
761-
vl -= np.sum(xlogy(self.r_vecs,self.r_vecs))
776+
self.vl_q_z = -np.sum(xlogy(self.r_vecs,self.r_vecs))
762777

763778
# E[ln q(pi)]
764-
vl += ss_dirichlet.entropy(self.hn_alpha_vec)
779+
self.vl_q_pi = ss_dirichlet.entropy(self.hn_alpha_vec)
765780

766781
# E[ln q(mu,Lambda)]
767-
vl += np.sum(
782+
self.vl_q_mu_lambda = np.sum(
768783
+ self.degree * (1.0 + np.log(2.0*np.pi) - np.log(self.hn_kappas))
769-
- self.LN_B_H0_W_NUS * 2.0
784+
- self.ln_b_hn_w_nus * 2.0
770785
- (self.hn_nus-self.degree)*self.e_ln_lambda_dets
771786
+ self.hn_nus * self.degree
772787
) / 2.0
773788

774-
return vl
789+
# print(self.vl_p_x,
790+
# self.vl_p_z,
791+
# self.vl_p_pi,
792+
# self.vl_p_mu_lambda,
793+
# self.vl_q_z,
794+
# self.vl_q_pi,
795+
# self.vl_q_mu_lambda,
796+
# )
797+
798+
self.vl = (self.vl_p_x
799+
+ self.vl_p_z
800+
+ self.vl_p_pi
801+
+ self.vl_p_mu_lambda
802+
+ self.vl_q_z
803+
+ self.vl_q_pi
804+
+ self.vl_q_mu_lambda)
805+
806+
def _calc_statistics(self,x):
807+
self.ns[:] = self.r_vecs.sum(axis=0)
808+
self.x_bar_vecs[:] = (self.r_vecs[:,:,np.newaxis] * x[:,np.newaxis,:]).sum(axis=0) / self.ns[:,np.newaxis]
809+
self.s_mats[:] = np.sum(self.r_vecs[:,:,np.newaxis,np.newaxis]
810+
* ((x[:,np.newaxis,:] - self.x_bar_vecs)[:,:,:,np.newaxis]
811+
@ (x[:,np.newaxis,:] - self.x_bar_vecs)[:,:,np.newaxis,:]),
812+
axis=0) / self.ns[:,np.newaxis,np.newaxis]
813+
814+
def _init_q_z(self):
815+
self.r_vecs[:] = self.rng.dirichlet(np.ones(self.num_classes),self.r_vecs.shape[0])
816+
817+
def _update_q_pi(self):
818+
self.hn_alpha_vec[:] = self.h0_alpha_vec + self.ns
819+
self.e_ln_pi_vec[:] = digamma(self.hn_alpha_vec) - digamma(self.hn_alpha_vec.sum())
820+
821+
def _update_q_mu_lambda(self):
822+
self.hn_kappas[:] = self.h0_kappas + self.ns
823+
self.hn_m_vecs[:] = (self.h0_kappas[:,np.newaxis] * self.h0_m_vecs
824+
+ self.ns[:,np.newaxis] * self.x_bar_vecs) / self.hn_kappas[:,np.newaxis]
825+
self.hn_nus[:] = self.h0_nus + self.ns
826+
self.hn_w_mats_inv[:] = (self.h0_w_mats_inv
827+
+ self.ns[:,np.newaxis,np.newaxis] * self.s_mats
828+
+ (self.h0_kappas * self.ns / self.hn_kappas)[:,np.newaxis,np.newaxis]
829+
* ((self.x_bar_vecs - self.h0_m_vecs)[:,:,np.newaxis]
830+
@ (self.x_bar_vecs - self.h0_m_vecs)[:,np.newaxis,:])
831+
)
832+
self.hn_w_mats[:] = np.linalg.inv(self.hn_w_mats_inv)
833+
self.e_lambda_mats[:] = self.hn_nus[:,np.newaxis,np.newaxis] * self.hn_w_mats
834+
self.e_ln_lambda_dets[:] = (np.sum(digamma((self.hn_nus[:,np.newaxis]-np.arange(self.degree)) / 2.0),axis=1)
835+
+ self.degree*np.log(2.0)
836+
- np.linalg.slogdet(self.hn_w_mats_inv)[1])
837+
self.ln_b_hn_w_nus[:] = (
838+
self.hn_nus*np.linalg.slogdet(self.hn_w_mats_inv)[1]
839+
- self.hn_nus*self.degree*np.log(2.0)
840+
- self.degree*(self.degree-1)/2.0*np.log(np.pi)
841+
- np.sum(gammaln((self.hn_nus[:,np.newaxis]-np.arange(self.degree)) / 2.0),
842+
axis=1) * 2.0
843+
) / 2.0
775844

776-
def update_posterior(self,x):
777-
pass
778-
# """Update the hyperparameters of the posterior distribution using training data.
845+
def _update_q_z(self,x):
846+
self.ln_rho[:] = (self.e_ln_pi_vec
847+
+ (self.e_ln_lambda_dets
848+
- self.degree * np.log(2*np.pi)
849+
- self.degree / self.hn_kappas
850+
- ((x[:,np.newaxis,:]-self.hn_m_vecs)[:,:,np.newaxis,:]
851+
@ self.e_lambda_mats
852+
@ (x[:,np.newaxis,:]-self.hn_m_vecs)[:,:,:,np.newaxis]
853+
)[:,:,0,0]
854+
) / 2.0
855+
)
856+
self.r_vecs[:] = np.exp(self.ln_rho - logsumexp(self.ln_rho,axis=1,keepdims=True))
857+
# self.r_vecs[:] = np.exp(self.ln_rho - self.ln_rho.max(axis=1,keepdims=True))
858+
# self.r_vecs[:] /= self.r_vecs.sum(axis=1,keepdims=True)
859+
860+
def update_posterior(self,x,max_itr=100,num_init=10,tolerance=1.0E-8):
861+
"""Update the hyperparameters of the posterior distribution using training data.
779862
780-
# Parameters
781-
# ----------
782-
# x : numpy.ndarray
783-
# All the elements must be real number.
784-
# """
785-
# _check.float_vecs(x,'x',DataFormatError)
786-
# if self.degree > 1 and x.shape[-1] != self.degree:
787-
# raise(DataFormatError(f"x.shape[-1] must be degree:{self.degree}"))
788-
# x = x.reshape(-1,self.degree)
789-
790-
# n = x.shape[0]
791-
# x_bar = x.sum(axis=0)/n
792-
793-
# self.hn_w_mat_inv[:] = (self.hn_w_mat_inv + (x-x_bar).T @ (x-x_bar)
794-
# + (x_bar - self.hn_m_vec)[:,np.newaxis] @ (x_bar - self.hn_m_vec)[np.newaxis,:]
795-
# * self.hn_kappa * n / (self.hn_kappa + n))
796-
# self.hn_m_vec[:] = (self.hn_kappa*self.hn_m_vec + n*x_bar) / (self.hn_kappa+n)
797-
# self.hn_kappa += n
798-
# self.hn_nu += n
799-
800-
# self.hn_w_mat[:] = np.linalg.inv(self.hn_w_mat_inv)
863+
Parameters
864+
----------
865+
x : numpy.ndarray
866+
All the elements must be real number.
867+
max_itr : int, optional
868+
maximum number of iterations, by default 100
869+
num_init : int, optional
870+
number of initializations, by default 10
871+
tolerance : float, optional
872+
convergence criterion for the variational lower bound, by default 1.0E-8
873+
"""
874+
_check.float_vecs(x,'x',DataFormatError)
875+
if self.degree > 1 and x.shape[-1] != self.degree:
876+
raise(DataFormatError(
877+
"x.shape[-1] must be self.degree: "
878+
+ f"x.shape[-1]={x.shape[-1]}, self.degree={self.degree}"))
879+
x = x.reshape(-1,self.degree)
880+
self.ln_rho = np.empty([x.shape[0],self.num_classes])
881+
self.r_vecs = np.empty([x.shape[0],self.num_classes])
882+
883+
tmp_vl = 0.0
884+
tmp_alpha_vec = np.copy(self.hn_alpha_vec)
885+
tmp_m_vecs = np.copy(self.hn_m_vecs)
886+
tmp_kappas = np.copy(self.hn_kappas)
887+
tmp_nus = np.copy(self.hn_nus)
888+
tmp_w_mats = np.copy(self.hn_w_mats)
889+
tmp_w_mats_inv = np.copy(self.hn_w_mats_inv)
890+
891+
for i in range(num_init):
892+
self._init_q_z()
893+
self._calc_statistics(x)
894+
self.calc_vl()
895+
print(f'\r{i}. VL: {self.vl}',end='')
896+
for t in range(max_itr):
897+
vl_before = self.vl
898+
899+
self._update_q_mu_lambda()
900+
self._update_q_pi()
901+
self._update_q_z(x)
902+
self._calc_statistics(x)
903+
self.calc_vl()
904+
print(f'\r{i}. VL: {self.vl}',end='')
905+
if np.abs((self.vl-vl_before)/vl_before) < tolerance:
906+
break
907+
if i==0 or self.vl > tmp_vl:
908+
print('*')
909+
tmp_vl = self.vl
910+
tmp_alpha_vec[:] = self.hn_alpha_vec
911+
tmp_m_vecs[:] = self.hn_m_vecs
912+
tmp_kappas[:] = self.hn_kappas
913+
tmp_nus[:] = self.hn_nus
914+
tmp_w_mats[:] = self.hn_w_mats
915+
tmp_w_mats_inv[:] = self.hn_w_mats_inv
916+
else:
917+
print('')
918+
919+
self.hn_alpha_vec[:] = tmp_alpha_vec
920+
self.hn_m_vecs[:] = tmp_m_vecs
921+
self.hn_kappas[:] = tmp_kappas
922+
self.hn_nus[:] = tmp_nus
923+
self.hn_w_mats[:] = tmp_w_mats
924+
self.hn_w_mats_inv[:] = tmp_w_mats_inv
801925

802926
def estimate_params(self,loss="squared"):
803927
pass

bayesml/gaussianmixture/gaussianmixture.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ $$
7979
\begin{align}
8080
N_k^{(t)} &= \sum_{i=1}^n r_{i,k}^{(t)}, \\
8181
\bar{\boldsymbol{x}}_k^{(t)} &= \frac{1}{N_k^{(t)}} \sum_{i=1}^n r_{i,k}^{(t)} \boldsymbol{x}_i, \\
82-
\boldsymbol{m}_{n,k}^{(t+1)} &= \frac{\kappa_0\boldsymbol{\mu}_0 + N_k^{(t)} \bar{\boldsymbol{x}}_k^{(t)}}{\kappa_0 + N_k^{(t)}}, \\
82+
\boldsymbol{m}_{n,k}^{(t+1)} &= \frac{\kappa_0\boldsymbol{m}_0 + N_k^{(t)} \bar{\boldsymbol{x}}_k^{(t)}}{\kappa_0 + N_k^{(t)}}, \\
8383
\kappa_{n,k}^{(t+1)} &= \kappa_0 + N_k^{(t)}, \\
84-
(\boldsymbol{W}_{n,k}^{(t+1)})^{-1} &= \boldsymbol{W}_0^{-1} + \sum_{i=1}^{n} r_{i,k}^{(t)} (\boldsymbol{x}_i-\bar{\boldsymbol{x}}_k^{(t)})(\boldsymbol{x}_i-\bar{\boldsymbol{x}}_k^{(t)})^\top + \frac{\kappa_0 N_k^{(t)}}{\kappa_0 + N_k^{(t)}}(\bar{\boldsymbol{x}}_k^{(t)}-\boldsymbol{\mu}_0)(\bar{\boldsymbol{x}}_k^{(t)}-\boldsymbol{\mu}_0)^\top, \\
84+
(\boldsymbol{W}_{n,k}^{(t+1)})^{-1} &= \boldsymbol{W}_0^{-1} + \sum_{i=1}^{n} r_{i,k}^{(t)} (\boldsymbol{x}_i-\bar{\boldsymbol{x}}_k^{(t)})(\boldsymbol{x}_i-\bar{\boldsymbol{x}}_k^{(t)})^\top + \frac{\kappa_0 N_k^{(t)}}{\kappa_0 + N_k^{(t)}}(\bar{\boldsymbol{x}}_k^{(t)}-\boldsymbol{m}_0)(\bar{\boldsymbol{x}}_k^{(t)}-\boldsymbol{m}_0)^\top, \\
8585
\nu_{n,k}^{(t+1)} &= \nu_0 + N_k^{(t)},\\
8686
\alpha_{n,k}^{(t+1)} &= \alpha_{0,k} + N_k^{(t)}, \\
8787
\ln \rho_{i,k}^{(t+1)} &= \psi (\alpha_{n,k}^{(t+1)}) - \psi ( {\textstyle \sum_{k=1}^K \alpha_{n,k}^{(t+1)}} ) \notag \\

bayesml/gaussianmixture/test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from bayesml import gaussianmixture
22
import numpy as np
33

4-
model = gaussianmixture.LearnModel(num_classes=3, degree=2, h0_w_mats=np.identity(2)*2)
5-
print(model.calc_vl())
4+
model = gaussianmixture.LearnModel(num_classes=5, degree=3)
5+
6+
x = np.random.rand(10,3)
7+
8+
model.update_posterior(x,num_init=3)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# __init__関数の仕様変更についての相談とお願い
2+
3+
## 背景
4+
5+
* ベイズ統計学的には事前分布と事後分布のハイパーパラメータは同じ形の方がうれしい.
6+
* 事前分布の意味
7+
* 逐次更新
8+
* 混合正規分布の場合
9+
* 事前分布のハイパーパラメータを一般化すると,インスタンス生成時のハイパーパラメータ指定方法が多様になりすぎる.
10+
* 個々の混合要素のハイパーパラメータを個別に指定する
11+
* 全混合要素のハイパーパラメータを共通にする(ブロードキャストできると嬉しい)
12+
* 次元がそろわない入力を受け付けなければならない.
13+
* __init__の実装の手間が大きい.
14+
15+
## 対応案
16+
17+
モデルのパラメータを3つに分ける.
18+
19+
* constants
20+
* 事後分布更新などに関与しないが,paramsやh_paramsの行列サイズを定める既知定数.degree, num_classesなど.
21+
* インスタンス生成時,必ず手動で与える形式にする(デフォルト値無し)
22+
* c_valname
23+
* params
24+
* サンプルを生成する確率分布のパラメータ
25+
* h_params, h0_params, hn_params
26+
* paramsを生成する確率分布(事前分布)のパラメータ
27+
* これらは同じ型の行列にする
28+
* p_params
29+
* 予測分布のパラメータ
30+
31+
次元確認は利用者が手動で与えたconstantsに対してparamsやh_paramsが整合しているかのみをチェックする.その際,ブロードキャストもうまく利用する.

0 commit comments

Comments
 (0)