Skip to content

Commit 62279de

Browse files
committed
modify ldsEm to use ldsPca as initialization
1 parent ca599be commit 62279de

File tree

2 files changed

+33
-29
lines changed

2 files changed

+33
-29
lines changed

chapter13/LDS/ldsEm.m

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
model = init(X,m);
1818
end
1919
tol = 1e-4;
20-
maxIter = 1000;
20+
maxIter = 2000;
2121
llh = -inf(1,maxIter);
2222
for iter = 2:maxIter
2323
% E-step
@@ -29,20 +29,23 @@
2929
llh = llh(2:iter);
3030

3131
function model = init(X, k)
32-
d = size(X,1);
33-
model.mu0 = randn(k,1);
34-
model.P0 = iwishrnd(eye(k),k);
35-
model.A = randn(k,k);
36-
model.G = iwishrnd(eye(k),k);
37-
model.C = randn(d,k);
38-
model.S = iwishrnd(eye(d),d);
39-
% [A,C,Z] = ldsPca(X,k,3*k);
40-
% model.mu0 = Z(:,1);
41-
% model.P0 = ;
42-
% model.A = A;
43-
% model.C = C;
44-
% model.G = ;
45-
% model.S = ;
32+
% d = size(X,1);
33+
% model.mu0 = randn(k,1);
34+
% model.P0 = iwishrnd(eye(k),k);
35+
% model.A = randn(k,k);
36+
% model.G = iwishrnd(eye(k),k);
37+
% model.C = randn(d,k);
38+
% model.S = iwishrnd(eye(d),d);
39+
[A,C,Z] = ldsPca(X,k,3*k);
40+
model.mu0 = Z(:,1);
41+
E = Z(:,1:end-1)-Z(:,2:end);
42+
model.P0 = (dot(E(:),E(:))/(k*size(E,2)))*eye(k);
43+
model.A = A;
44+
E = A*Z(:,1:end-1)-Z(:,2:end);
45+
model.G = E*E'/size(E,2);
46+
model.C = C;
47+
E = C*Z-X(:,1:size(Z,2));
48+
model.S = E*E'/size(E,2);
4649

4750
function model = maximization(X ,nu, U, Ezz, Ezy)
4851
n = size(X,2);

demo/ch13/lds_demo.m

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
close all;
2-
%% Parameter
2+
% Parameter
33
clear;
44
d = 2;
5-
k = 2;
6-
n = 50;
5+
k = 3;
6+
n = 100;
77

8-
A = [1,1;
9-
0 1];
8+
A = [1,0,1;
9+
0 1,0;
10+
0,0,1];
1011
G = eye(k)*1e-3;
1112

12-
C = [1 0;
13-
0 1];
13+
C = [1,0,0;
14+
0 1,0];
1415
S = eye(d)*1e-1;
1516

16-
mu0 = [0; 0];
17+
mu0 = [0;0;0];
1718
P0 = eye(k);
1819

1920
model.A = A;
@@ -54,9 +55,9 @@
5455
axis equal
5556
hold off
5657
%% LDS Subspace
57-
[A,C,z] = ldsPca(x,k,3*k);
58-
y = C*z;
59-
t = size(z,2);
58+
[A,C,nu] = ldsPca(x,k,3*k);
59+
y = C*nu;
60+
t = size(y,2);
6061
figure;
6162
hold on
6263
plot(x(1,1:t), x(2,1:t), 'ro');
@@ -66,9 +67,9 @@
6667
axis equal
6768
hold off
6869
%% LDS EM
69-
[model, llh] = ldsEm(x,k);
70-
nu = kalmanSmoother(model,x);
71-
y = model.C*nu;
70+
[tmodel, llh] = ldsEm(x,k);
71+
nu = kalmanSmoother(tmodel,x);
72+
y = tmodel.C*nu;
7273
figure
7374
hold on
7475
plot(x(1,:), x(2,:), 'ro');

0 commit comments

Comments
 (0)