Skip to content

Commit 2fb3339

Browse files
committed
rewrite mlp and backpropagation for regression
1 parent fcf7e0d commit 2fb3339

File tree

4 files changed

+82
-52
lines changed

4 files changed

+82
-52
lines changed

chapter05/mlp.m

Lines changed: 0 additions & 39 deletions
This file was deleted.

chapter05/mlpReg.m

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
function [model, L] = mlpReg(X,Y,k,lambda)
% Train a multilayer perceptron for regression
% (tanh hidden units, linear output, squared-error loss with L2 penalty).
% Input:
%   X: d x n data matrix
%   Y: p x n response matrix
%   k: T x 1 vector to specify number of hidden nodes in each layer
%   lambda: regularization parameter (default 1e-2)
% Output:
%   model: model structure with fields W (weights) and b (biases)
%   L: vector of loss values, one per completed iteration
% Written by Mo Chen (sth4nth@gmail.com).
if nargin < 4
    lambda = 1e-2;
end
eta = 1e-3;         % gradient descent step size
tol = 1e-8;         % relative tolerance for early stopping
maxiter = 50000;
L = inf(1,maxiter);

k = [size(X,1);k(:);size(Y,1)];     % layer sizes: input, hidden..., output
T = numel(k)-1;                     % number of weight layers
W = cell(T,1);
b = cell(T,1);
for t = 1:T
    W{t} = randn(k(t),k(t+1));
    b{t} = randn(k(t+1),1);
end
R = cell(T,1);      % delta (back-propagated error) per layer
Z = cell(T+1,1);    % activations per layer
Z{1} = X;
for iter = 2:maxiter
    % forward pass: tanh hidden layers, linear output layer
    for t = 1:T-1
        Z{t+1} = tanh(W{t}'*Z{t}+b{t});
    end
    Z{T+1} = W{T}'*Z{T}+b{T};

    % loss: squared error plus L2 weight penalty
    E = Z{T+1}-Y;
    Wn = cellfun(@(x) dot(x(:),x(:)),W); % |W|^2
    L(iter) = dot(E(:),E(:))+lambda*sum(Wn);
    % early stopping once the loss has effectively converged
    % (L(1) is inf, so this never triggers on the first iteration)
    if abs(L(iter)-L(iter-1)) < tol*abs(L(iter))
        break;
    end

    % backward pass
    R{T} = E;                        % delta at the linear output
    for t = T-1:-1:1
        df = 1-Z{t+1}.^2;            % h'(a) for tanh
        R{t} = df.*(W{t+1}*R{t+1});  % delta
    end

    % gradient descent step
    for t = 1:T
        dW = Z{t}*R{t}'+lambda*W{t};
        db = sum(R{t},2);
        W{t} = W{t}-eta*dW;
        b{t} = b{t}-eta*db;
    end
end
L = L(1,2:iter);
model.W = W;
model.b = b;
Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function Y = mlpRegPred(model, X)
% Multilayer perceptron regression prediction.
% Input:
%   model: model structure with fields W (weights) and b (biases)
%   X: d x n data matrix
% Output:
%   Y: p x n response matrix
% Written by Mo Chen (sth4nth@gmail.com).
W = model.W;
b = model.b;
T = length(W);
A = X;                          % running activation of the current layer
for t = 1:T-1
    A = tanh(W{t}'*A+b{t});     % tanh hidden layers
end
Y = W{T}'*A+b{T};               % linear output layer

demo/ch05/mlp_demo.m

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
% Demo: fit y = sin(x) with a two-hidden-layer MLP regressor.
clear; close all;
n = 200;
x = linspace(0,2*pi,n);     % 1 x n inputs
y = sin(x);                 % 1 x n targets

k = [3,4];                  % two hidden layers with 3 and 4 hidden nodes
lambda = 1e-2;              % regularization parameter
[model, L] = mlpReg(x,y,k,lambda);   % pass lambda explicitly (it was defined but unused)
t = mlpRegPred(model,x);
plot(L);                    % training loss curve
figure;
hold on
plot(x,y,'.');              % training data
plot(x,t);                  % fitted curve
hold off

0 commit comments

Comments
 (0)