Skip to content

Commit f335f71

Browse files
author
zhangxiong
committed
bug fix. reflect bug.
1 parent a674965 commit f335f71

21 files changed

+1270
-379
lines changed

src/Discriminator.py

Lines changed: 56 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11

2+
'''
3+
file: Discriminator.py
4+
5+
date: 2017_04_29
6+
author: zhangxiong(1025679612@qq.com)
7+
'''
8+
29
from LinearModel import LinearModel
310
import config
411
import util
512
import torch
613
import numpy as np
714
import torch.nn as nn
815
from config import args
9-
import torch.nn.functional as F
1016

17+
'''
18+
shape discriminator is used for shape discriminator
19+
the inputs if N x 10
20+
'''
1121
class ShapeDiscriminator(LinearModel):
1222
def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
1323
if fc_layers[-1] != 1:
@@ -19,20 +29,36 @@ def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
1929
def forward(self, inputs):
2030
return self.fc_blocks(inputs)
2131

22-
class PoseDiscriminator(LinearModel):
23-
def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
24-
if fc_layers[-1] != 1:
25-
msg = 'the neuron count of the last layer must be 1, but got {}'.format(fc_layers[-1])
32+
class PoseDiscriminator(nn.Module):
33+
def __init__(self, channels):
34+
super(PoseDiscriminator, self).__init__()
35+
36+
if channels[-1] != 1:
37+
msg = 'the neuron count of the last layer must be 1, but got {}'.format(channels[-1])
2638
sys.exit(msg)
2739

28-
super(PoseDiscriminator, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func)
40+
self.conv_blocks = nn.Sequential()
41+
l = len(channels)
42+
for idx in range(l - 2):
43+
self.conv_blocks.add_module(
44+
name = 'conv_{}'.format(idx),
45+
module = nn.Conv2d(in_channels = channels[idx], out_channels = channels[idx + 1], kernel_size = 1, stride = 1)
46+
)
2947

48+
self.fc_layer = nn.ModuleList()
49+
for idx in range(23):
50+
self.fc_layer.append(nn.Linear(in_features = channels[l - 2], out_features = 1))
51+
52+
# N x 23 x 9
3053
def forward(self, inputs):
31-
'''
32-
x = self.fc_blocks(inputs)
33-
return [x, self.last_block(x)]
34-
'''
35-
return self.fc_blocks(inputs)
54+
batch_size = inputs.shape[0]
55+
inputs = inputs.transpose(1, 2).unsqueeze(2) # to N x 9 x 1 x 23
56+
internal_outputs = self.conv_blocks(inputs) # to N x c x 1 x 23
57+
o = []
58+
for idx in range(23):
59+
o.append(self.fc_layer[idx](internal_outputs[:,:,0,idx]))
60+
61+
return torch.cat(o, 1), internal_outputs
3662

3763
class FullPoseDiscriminator(LinearModel):
3864
def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
@@ -64,21 +90,16 @@ def _create_sub_modules(self):
6490
'''
6591
create theta discriminator for 23 joint
6692
'''
67-
fc_layers = [9, 32, 32, 1]
68-
use_dropout = [False, False, False]
69-
drop_prob = [0.5, 0.5, 0.5]
70-
use_ac_func = [True, True, False]
71-
self.pose_discriminators = nn.ModuleList()
72-
for _ in range(self.joint_count - 1):
73-
self.pose_discriminators.append(PoseDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func))
93+
94+
self.pose_discriminator = PoseDiscriminator([9, 32, 32, 1])
7495

7596
'''
7697
create full pose discriminator for total 23 joints
7798
'''
78-
fc_layers = [(self.joint_count - 1) * 9, 1024, 1024, 1024, 1]
79-
use_dropout = [False, False, False, False]
80-
drop_prob = [0.5, 0.5, 0.5, 0.5]
81-
use_ac_func = [True, True, True, False]
99+
fc_layers = [23 * 32, 1024, 1024, 1]
100+
use_dropout = [False, False, False]
101+
drop_prob = [0.5, 0.5, 0.5]
102+
use_ac_func = [True, True, False]
82103
self.full_pose_discriminator = FullPoseDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func)
83104

84105
'''
@@ -92,67 +113,22 @@ def _create_sub_modules(self):
92113

93114
print('finished create the discriminator modules...')
94115

95-
'''
96-
purpose:
97-
calc mean shape discriminator value
98-
inputs:
99-
real_shape N x 10
100-
fake_shape n x 10
101-
return:
102-
shape discriminator output value
103-
'''
104-
def calc_shape_disc_value(self, real_shape, fake_shape):
105-
shapes = torch.cat([real_shape, fake_shape], dim = 0)
106-
return self.shape_discriminator(shapes)
107-
108-
'''
109-
inputs:
110-
real_pose N x 24 x 3
111-
fake_pose n x 24 x 3
112-
return:
113-
pose discriminator output value
114-
'''
115-
def calc_pose_disc_value(self, real_pose, fake_pose):
116-
real_pose = util.batch_rodrigues(real_pose.view(-1, 3)).view(-1, 24, 9)
117-
fake_pose = util.batch_rodrigues(fake_pose.view(-1, 3)).view(-1, 24, 9)
118-
poses = torch.cat((real_pose[:, 1:, :], fake_pose[:, 1:, :]), dim = 0)
119-
full_pose_dis_value = self.full_pose_discriminator(poses.view(-1, 23 * 9))
120-
poses = torch.transpose(poses, 0, 1)
121-
theta_disc_values = []
122-
for _ in range(23):
123-
theta_disc_values.append(
124-
self.pose_discriminators[_](poses[_, :, :])
125-
)
126-
pose_dis_value = torch.cat(theta_disc_values, dim = 1)
127-
return torch.cat([pose_dis_value, full_pose_dis_value], dim = 1)
128116

129117
'''
130-
inputs:
131-
real_thetas N x 85
132-
fake_thetas N x 85
133-
return
134-
pose & full pose & shape disc value N x (23 + 1 + 1)
118+
inputs is N x 85(3 + 72 + 10)
135119
'''
136-
def calc_thetas_disc_value(self, real_thetas, fake_thetas):
137-
real_poses, fake_poses = real_thetas[:, 3:75], fake_thetas[:, 3:75]
138-
real_shapes, fake_shapes = real_thetas[:, 75:], fake_thetas[:, 75:]
139-
pose_disc_value = self.calc_pose_disc_value(real_poses.contiguous(), fake_poses.contiguous())
140-
shape_disc_value = self.calc_shape_disc_value(real_shapes.contiguous(), fake_shapes.contiguous())
141-
return torch.cat([pose_disc_value, shape_disc_value], dim = 1)
142-
143-
def forward(self, real_thetas, fake_thetas):
144-
if config.args.normalize_disc:
145-
return F.sigmoid(self.calc_thetas_disc_value(real_thetas, fake_thetas))
146-
else:
147-
return self.calc_thetas_disc_value(real_thetas, fake_thetas)
148-
120+
def forward(self, thetas):
121+
batch_size = thetas.shape[0]
122+
cams, poses, shapes = thetas[:, :3], thetas[:, 3:75], thetas[:, 75:]
123+
shape_disc_value = self.shape_discriminator(shapes)
124+
rotate_matrixs = util.batch_rodrigues(poses.contiguous().view(-1, 3)).view(-1, 24, 9)[:, 1:, :]
125+
pose_disc_value, pose_inter_disc_value = self.pose_discriminator(rotate_matrixs)
126+
full_pose_disc_value = self.full_pose_discriminator(pose_inter_disc_value.contiguous().view(batch_size, -1))
127+
return torch.cat((pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
149128

150129
if __name__ == '__main__':
151130
device = torch.device('cuda')
152-
net = Discriminator().to(device)
153-
real = torch.zeros((100, 85)).float().to(device)
154-
fake = torch.ones((200, 85)).float().to(device)
155-
156-
dis_v = net(real, fake)
157-
print(dis_v.device)
158-
print(dis_v.shape)
131+
net = Discriminator()
132+
inputs = torch.ones((100, 85))
133+
disc_value = net(inputs)
134+
print(net)

src/HourGlass.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11

2+
'''
3+
file: hourglass.py
4+
5+
date: 2018_05_12
6+
author: zhangxiong(1025679612@qq.com)
7+
'''
8+
29
from __future__ import print_function
310
import numpy as np
411
import torch
@@ -208,4 +215,4 @@ def _create_hourglass_net():
208215
nChannels = 256,
209216
nJointCount = 1,
210217
bUseBn = True,
211-
)
218+
)

src/LinearModel.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11

2+
3+
'''
4+
file: LinearModel.py
5+
6+
date: 2018_04_29
7+
author: zhangxiong(1025679612@qq.com)
8+
'''
9+
210
import torch.nn as nn
311
import numpy as np
412
import sys
@@ -93,4 +101,4 @@ def forward(self, inputs):
93101
net = LinearModel(fc_layers, use_dropout, drop_prob, use_ac_func).to(device)
94102
print(net)
95103
nx = np.zeros([2, 2048])
96-
vx = torch.from_numpy(nx).to(device)
104+
vx = torch.from_numpy(nx).to(device)

src/PRNetEncoder.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
2+
3+
'''
4+
file: PRnetEncoder.py
5+
6+
date: 2018_05_22
7+
author: zhangxiong(1025679612@qq.com)
8+
mark: the algorithm is cited from PRNet code
9+
'''
10+
11+
from __future__ import print_function
12+
import numpy as np
13+
import torch
14+
import torch.nn as nn
15+
import torch.nn.functional as F
16+
17+
class Residual(nn.Module):
18+
def __init__(self, use_bn, input_channels, out_channels, mid_channels, kernel_size = 3, padding = 1, stride = 1):
19+
super(Residual, self).__init__()
20+
self.use_bn = use_bn
21+
self.out_channels = out_channels
22+
self.input_channels = input_channels
23+
self.mid_channels = mid_channels
24+
25+
self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1)
26+
self.AcFunc = nn.ReLU()
27+
if use_bn:
28+
self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels)
29+
self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels)
30+
self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels)
31+
32+
self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = kernel_size, padding = padding, stride = stride)
33+
34+
self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size= 1)
35+
36+
if input_channels != out_channels:
37+
self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1)
38+
39+
def forward(self, inputs):
40+
x = self.down_channel(inputs)
41+
if self.use_bn:
42+
x = self.bn_0(x)
43+
x = self.AcFunc(x)
44+
45+
x = self.conv(x)
46+
if self.use_bn:
47+
x = self.bn_1(x)
48+
x = self.AcFunc(x)
49+
50+
x = self.up_channel(x)
51+
52+
if self.input_channels != self.out_channels:
53+
x += self.trans(inputs)
54+
else:
55+
x += inputs
56+
57+
if self.use_bn:
58+
x = self.bn_2(x)
59+
60+
return self.AcFunc(x)
61+
62+
class PRNetEncoder(nn.Module):
63+
def __init__(self):
64+
super(PRNetEncoder, self).__init__()
65+
self.conv_blocks = nn.Sequential(
66+
nn.Conv2d(in_channels = 3, out_channels = 8, kernel_size = 3, stride = 1, padding = 1), # to 256 x 256 x 8
67+
nn.Conv2d(in_channels = 8, out_channels = 16, kernel_size = 3, stride = 1, padding = 1), # to 256 x 256 x 16
68+
Residual(use_bn = True, input_channels = 16, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 256 x 256 x 32
69+
nn.MaxPool2d(kernel_size = 2, stride = 2), # to 128 x 128 x 32
70+
Residual(use_bn = True, input_channels = 32, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 128 x 128 x 32
71+
Residual(use_bn = True, input_channels = 32, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 128 x 128 x 32
72+
Residual(use_bn = True, input_channels = 32, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 128 x 128 x 64
73+
nn.MaxPool2d(kernel_size = 2, stride = 2), # to 64 x 64 x 64
74+
Residual(use_bn = True, input_channels = 64, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 64 x 64 x 64
75+
Residual(use_bn = True, input_channels = 64, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 64 x 64 x 64
76+
Residual(use_bn = True, input_channels = 64, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 64 x 64 x 128
77+
nn.MaxPool2d(kernel_size = 2, stride = 2), # to 32 x 32 x 128
78+
Residual(use_bn = True, input_channels = 128, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 32 x 32 x 128
79+
Residual(use_bn = True, input_channels = 128, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 32 x 32 x 128
80+
Residual(use_bn = True, input_channels = 128, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 32 x 32 x 256
81+
nn.MaxPool2d(kernel_size = 2, stride = 2), # to 16 x 16 x 256
82+
Residual(use_bn = True, input_channels = 256, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 16 x 16 x 256
83+
Residual(use_bn = True, input_channels = 256, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 16 x 16 x 256
84+
Residual(use_bn = True, input_channels = 256, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 16 x 16 x 512
85+
nn.MaxPool2d(kernel_size = 2, stride = 2), # to 8 x 8 x 512
86+
Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 8 x 8 x 512
87+
nn.MaxPool2d(kernel_size = 2, stride = 2) , # to 4 x 4 x 512
88+
Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 4 x 4 x 512
89+
nn.MaxPool2d(kernel_size = 2, stride = 2), # to 2 x 2 x 512
90+
Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1) # to 2 x 2 x 512
91+
)
92+
93+
def forward(self, inputs):
94+
return self.conv_blocks(inputs).view(-1, 2048)
95+
96+
97+
if __name__ == '__main__':
98+
net = PRNetEncoder()
99+
inputs = torch.ones(size = (10, 3, 256, 256)).float()
100+
r = net(inputs)
101+
print(r.shape)

src/Resnet.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11

2+
'''
3+
file: Resnet.py
4+
5+
date: 2018_05_02
6+
author: zhangxiong(1025679612@qq.com)
7+
mark: copied from pytorch sourc code
8+
'''
9+
210
import torch.nn as nn
311
import torch.nn.functional as F
412
import torch

0 commit comments

Comments
 (0)