+'''
+    file:   Discriminator.py
+
+    date:   2017_04_29
+    author: zhangxiong(1025679612@qq.com)
+'''
+
from LinearModel import LinearModel
import config
import util
import torch
import numpy as np
import torch.nn as nn
from config import args
-import torch.nn.functional as F
+import sys   # needed for the sys.exit calls below, missing in the original file

+'''
+    the shape discriminator scores SMPL shape parameters
+    the input is N x 10
+'''
class ShapeDiscriminator(LinearModel):
    def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
        if fc_layers[-1] != 1:
@@ -19,20 +29,36 @@ def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
    def forward(self, inputs):
        return self.fc_blocks(inputs)

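ShapeDiscriminator is a thin wrapper over the repo's LinearModel, whose definition is outside this diff. Assuming LinearModel chains nn.Linear blocks, inserting ReLU where use_ac_func is True and dropout where use_dropout is True, a hypothetical equivalent over the N x 10 shape input would look like this (the [10, 5, 1] widths are illustrative only, since the real fc_layers are set in _create_sub_modules):

    import torch
    import torch.nn as nn

    # hypothetical stand-in for ShapeDiscriminator, widths illustrative only
    shape_disc = nn.Sequential(
        nn.Linear(10, 5),   # 10 SMPL shape parameters (betas) in
        nn.ReLU(),
        nn.Linear(5, 1),    # single realism score out, no final activation
    )
    print(shape_disc(torch.randn(4, 10)).shape)   # torch.Size([4, 1])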
-class PoseDiscriminator(LinearModel):
-    def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
-        if fc_layers[-1] != 1:
-            msg = 'the neuron count of the last layer must be 1, but got {}'.format(fc_layers[-1])
+class PoseDiscriminator(nn.Module):
+    def __init__(self, channels):
+        super(PoseDiscriminator, self).__init__()
+
+        if channels[-1] != 1:
+            msg = 'the neuron count of the last layer must be 1, but got {}'.format(channels[-1])
            sys.exit(msg)

-        super(PoseDiscriminator, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func)
+        # 1 x 1 convolutions share their weights across all 23 joints
+        self.conv_blocks = nn.Sequential()
+        l = len(channels)
+        for idx in range(l - 2):
+            self.conv_blocks.add_module(
+                name = 'conv_{}'.format(idx),
+                module = nn.Conv2d(in_channels = channels[idx], out_channels = channels[idx + 1], kernel_size = 1, stride = 1)
+            )
+
+        # one joint-specific linear scoring head per joint
+        self.fc_layer = nn.ModuleList()
+        for idx in range(23):
+            self.fc_layer.append(nn.Linear(in_features = channels[l - 2], out_features = 1))
+
+    # the input is N x 23 x 9
    def forward(self, inputs):
-        '''
-            x = self.fc_blocks(inputs)
-            return [x, self.last_block(x)]
-        '''
-        return self.fc_blocks(inputs)
+        batch_size = inputs.shape[0]
+        inputs = inputs.transpose(1, 2).unsqueeze(2)   # to N x 9 x 1 x 23
+        internal_outputs = self.conv_blocks(inputs)    # to N x c x 1 x 23
+        o = []
+        for idx in range(23):
+            o.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))
+
+        return torch.cat(o, 1), internal_outputs

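The rewritten PoseDiscriminator folds the former 23 independent per-joint discriminators into one module: the joints are laid out along the width axis of a 4-D tensor, so the 1 x 1 convolutions act as an MLP whose weights are shared across joints, while the 23 nn.Linear heads remain joint-specific. A minimal sketch of the shape flow, using the [9, 32, 32, 1] channels passed in below:

    import torch
    import torch.nn as nn

    x = torch.randn(4, 23, 9)               # N x 23 joints x 9 (flattened 3 x 3 rotation)
    x = x.transpose(1, 2).unsqueeze(2)      # N x 9 x 1 x 23, joints along the width axis
    conv = nn.Sequential(                   # mirrors conv_blocks for channels [9, 32, 32, 1]
        nn.Conv2d(9, 32, kernel_size=1),    # 1 x 1 kernels: weights shared by all joints
        nn.Conv2d(32, 32, kernel_size=1),
    )
    feat = conv(x)                          # N x 32 x 1 x 23
    heads = nn.ModuleList(nn.Linear(32, 1) for _ in range(23))
    scores = torch.cat([heads[j](feat[:, :, 0, j]) for j in range(23)], 1)
    print(scores.shape)                     # torch.Size([4, 23]), one score per joint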
class FullPoseDiscriminator(LinearModel):
    def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
@@ -64,21 +90,16 @@ def _create_sub_modules(self):
        '''
            create theta discriminator for 23 joints
        '''
-        fc_layers = [9, 32, 32, 1]
-        use_dropout = [False, False, False]
-        drop_prob = [0.5, 0.5, 0.5]
-        use_ac_func = [True, True, False]
-        self.pose_discriminators = nn.ModuleList()
-        for _ in range(self.joint_count - 1):
-            self.pose_discriminators.append(PoseDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func))
+
+        self.pose_discriminator = PoseDiscriminator([9, 32, 32, 1])

        '''
            create full pose discriminator for total 23 joints
        '''
-        fc_layers = [(self.joint_count - 1) * 9, 1024, 1024, 1024, 1]
-        use_dropout = [False, False, False, False]
-        drop_prob = [0.5, 0.5, 0.5, 0.5]
-        use_ac_func = [True, True, True, False]
+        fc_layers = [23 * 32, 1024, 1024, 1]
+        use_dropout = [False, False, False]
+        drop_prob = [0.5, 0.5, 0.5]
+        use_ac_func = [True, True, False]
        self.full_pose_discriminator = FullPoseDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func)

        '''
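Note the changed input width for the full pose discriminator: it no longer sees the raw 23 x 9 rotation matrices but the 23 x 32 intermediate feature maps that PoseDiscriminator returns as its second output, flattened to 23 * 32 = 736 values per sample. Assuming LinearModel builds Linear blocks with ReLU where use_ac_func is True, a hypothetical equivalent of the configuration above:

    import torch.nn as nn

    # hypothetical equivalent of FullPoseDiscriminator([23 * 32, 1024, 1024, 1], ...)
    full_pose_mlp = nn.Sequential(
        nn.Linear(23 * 32, 1024), nn.ReLU(),   # use_ac_func[0] is True
        nn.Linear(1024, 1024), nn.ReLU(),      # use_ac_func[1] is True
        nn.Linear(1024, 1),                    # final score, use_ac_func[2] is False
    )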
@@ -92,67 +113,22 @@ def _create_sub_modules(self):

        print('finished create the discriminator modules...')

-    '''
-        purpose:
-            calc mean shape discriminator value
-        inputs:
-            real_shape N x 10
-            fake_shape n x 10
-        return:
-            shape discriminator output value
-    '''
-    def calc_shape_disc_value(self, real_shape, fake_shape):
-        shapes = torch.cat([real_shape, fake_shape], dim = 0)
-        return self.shape_discriminator(shapes)
-
-    '''
-        inputs:
-            real_pose N x 24 x 3
-            fake_pose n x 24 x 3
-        return:
-            pose discriminator output value
-    '''
-    def calc_pose_disc_value(self, real_pose, fake_pose):
-        real_pose = util.batch_rodrigues(real_pose.view(-1, 3)).view(-1, 24, 9)
-        fake_pose = util.batch_rodrigues(fake_pose.view(-1, 3)).view(-1, 24, 9)
-        poses = torch.cat((real_pose[:, 1:, :], fake_pose[:, 1:, :]), dim = 0)
-        full_pose_dis_value = self.full_pose_discriminator(poses.view(-1, 23 * 9))
-        poses = torch.transpose(poses, 0, 1)
-        theta_disc_values = []
-        for _ in range(23):
-            theta_disc_values.append(
-                self.pose_discriminators[_](poses[_, :, :])
-            )
-        pose_dis_value = torch.cat(theta_disc_values, dim = 1)
-        return torch.cat([pose_dis_value, full_pose_dis_value], dim = 1)

    '''
-        inputs:
-            real_thetas N x 85
-            fake_thetas N x 85
-        return
-            pose & full pose & shape disc value N x (23 + 1 + 1)
+        inputs is N x 85 (3 cam + 72 pose + 10 shape)
    '''
-    def calc_thetas_disc_value(self, real_thetas, fake_thetas):
-        real_poses, fake_poses = real_thetas[:, 3:75], fake_thetas[:, 3:75]
-        real_shapes, fake_shapes = real_thetas[:, 75:], fake_thetas[:, 75:]
-        pose_disc_value = self.calc_pose_disc_value(real_poses.contiguous(), fake_poses.contiguous())
-        shape_disc_value = self.calc_shape_disc_value(real_shapes.contiguous(), fake_shapes.contiguous())
-        return torch.cat([pose_disc_value, shape_disc_value], dim = 1)
-
-    def forward(self, real_thetas, fake_thetas):
-        if config.args.normalize_disc:
-            return F.sigmoid(self.calc_thetas_disc_value(real_thetas, fake_thetas))
-        else:
-            return self.calc_thetas_disc_value(real_thetas, fake_thetas)
-
+    def forward(self, thetas):
+        batch_size = thetas.shape[0]
+        cams, poses, shapes = thetas[:, :3], thetas[:, 3:75], thetas[:, 75:]
+        shape_disc_value = self.shape_discriminator(shapes)
+        # drop the global rotation (joint 0): 24 -> 23 joint rotation matrices
+        rotate_matrixs = util.batch_rodrigues(poses.contiguous().view(-1, 3)).view(-1, 24, 9)[:, 1:, :]
+        pose_disc_value, pose_inter_disc_value = self.pose_discriminator(rotate_matrixs)
+        full_pose_disc_value = self.full_pose_discriminator(pose_inter_disc_value.contiguous().view(batch_size, -1))
+        return torch.cat((pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
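Discriminator.forward now scores one batch of thetas and returns N x 25 values per the concatenation above: 23 per-joint scores, 1 full-pose score, and 1 shape score. The old real/fake pair API is gone, so presumably the caller concatenates real and fake parameters before invoking the net. util.batch_rodrigues is defined elsewhere in the repo; below is a minimal stand-in for the mapping it is assumed to perform, axis-angle to flattened 3 x 3 rotation matrices via Rodrigues' formula:

    import torch

    def batch_rodrigues_sketch(axisang, eps=1e-8):
        # axisang: M x 3 axis-angle vectors -> M x 9 flattened rotation matrices
        # (stand-in for util.batch_rodrigues, whose implementation is not in this diff)
        angle = axisang.norm(dim=1, keepdim=True).clamp(min=eps)    # M x 1
        axis = axisang / angle                                      # unit rotation axes
        x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
        zero = torch.zeros_like(x)
        K = torch.stack([zero, -z, y, z, zero, -x, -y, x, zero], 1).view(-1, 3, 3)
        eye = torch.eye(3).expand(axisang.shape[0], 3, 3)
        sin, cos = angle.sin().view(-1, 1, 1), angle.cos().view(-1, 1, 1)
        R = eye + sin * K + (1 - cos) * (K @ K)                     # Rodrigues' formula
        return R.view(-1, 9)

    print(batch_rodrigues_sketch(torch.randn(24, 3)).shape)         # torch.Size([24, 9])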

if __name__ == '__main__':
    device = torch.device('cuda')
-    net = Discriminator().to(device)
-    real = torch.zeros((100, 85)).float().to(device)
-    fake = torch.ones((200, 85)).float().to(device)
-
-    dis_v = net(real, fake)
-    print(dis_v.device)
-    print(dis_v.shape)
+    net = Discriminator()
+    inputs = torch.ones((100, 85))
+    disc_value = net(inputs)
+    print(net)
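The new smoke test builds the net on CPU (the device variable above is now unused) and prints the module structure. A one-line extension, an assumed addition rather than part of this commit, would also check the output contract implied by forward's 23 + 1 + 1 concatenation:

    # assumed addition: verify the 23 per-joint + 1 full-pose + 1 shape layout
    print(disc_value.shape)   # expected: torch.Size([100, 25])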