from scipy import linalg
from torch.autograd import Variable
from torch.nn.functional import adaptive_avg_pool2d
+from tqdm import tqdm

from inception import InceptionV3

-
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('path', type=str, nargs=2,
                    help=('Path to the generated images or '
                          'to .npz statistic files'))
-parser.add_argument('--batch-size', type=int, default=64,
+parser.add_argument('--batch-size', type=int, default=256,
                    help='Batch size to use')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
...
                    help='GPU to use (leave blank for CPU only)')


-def get_activations(images, model, batch_size=64, dims=2048,
+def get_activations(files, model, batch_size=64, dims=2048,
                    cuda=False, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.
@@ -81,8 +81,9 @@ def get_activations(images, model, batch_size=64, dims=2048,
       query tensor.
    """
    model.eval()
-
-    d0 = images.shape[0]
+
+    # Calculate the total number of image files
+    d0 = len(files)
    if batch_size > d0:
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
@@ -92,14 +93,21 @@ def get_activations(images, model, batch_size=64, dims=2048,
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))
-    for i in range(n_batches):
+
+    # Add a progress bar to show progress over the batches
+    for i in tqdm(range(n_batches)):
        if verbose:
            print('\rPropagating batch %d/%d' % (i + 1, n_batches),
                  end='', flush=True)
        start = i * batch_size
        end = start + batch_size
-
-        batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
+
+        # Read only this batch of images from disk
+        images = np.array([imread(str(fn)).astype(np.float32) for fn in files[start:end]])
+        # Bring images to shape (B, 3, H, W) and rescale to [0, 1]
+        images = images.transpose((0, 3, 1, 2))
+        images /= 255
+
+        batch = torch.from_numpy(images).type(torch.FloatTensor)
        batch = Variable(batch, volatile=True)
        if cuda:
            batch = batch.cuda()
@@ -176,7 +184,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
            np.trace(sigma2) - 2 * tr_covmean)


-def calculate_activation_statistics(images, model, batch_size=64,
+def calculate_activation_statistics(files, model, batch_size=64,
                                    dims=2048, cuda=False, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
@@ -196,13 +204,14 @@ def calculate_activation_statistics(images, model, batch_size=64,
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
-    act = get_activations(images, model, batch_size, dims, cuda, verbose)
+    # Instead of loading all the images at once, pass the list of file names
+    act = get_activations(files, model, batch_size, dims, cuda, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


-def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
+def _compute_statistics_of_path(path, model, batch_size, dims, cuda, flag):
    if path.endswith('.npz'):
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
@@ -211,15 +220,16 @@ def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))

-        imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
+        # imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])

        # Bring images to shape (B, 3, H, W)
-        imgs = imgs.transpose((0, 3, 1, 2))
+        # imgs = imgs.transpose((0, 3, 1, 2))

        # Rescale images to be between 0 and 1
-        imgs /= 255
+        # imgs /= 255

-        m, s = calculate_activation_statistics(imgs, model, batch_size,
+        # Instead of loading all the images at once, pass the list of file names
+        m, s = calculate_activation_statistics(files, model, batch_size,
                                                dims, cuda)

    return m, s
@@ -236,11 +246,11 @@ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    model = InceptionV3([block_idx])
    if cuda:
        model.cuda()
-
+
    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
-                                         dims, cuda)
+                                         dims, cuda, 1)
    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
-                                         dims, cuda)
+                                         dims, cuda, 0)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value
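
For reference, a minimal usage sketch of the patched entry point. The module name fid_score and the two directory paths are assumptions (neither appears in this diff); the signature calculate_fid_given_paths(paths, batch_size, cuda, dims) is taken from the hunk header above, and both directories are assumed to hold same-sized .jpg/.png images, since each batch is stacked into a single NumPy array:

    # Minimal sketch; the module name fid_score is an assumption, since the
    # file name does not appear in this diff.
    from fid_score import calculate_fid_given_paths

    fid = calculate_fid_given_paths(
        ['data/real_images', 'data/generated_images'],  # hypothetical paths
        batch_size=256,  # the new default introduced by this change
        cuda=False,      # set True to run InceptionV3 on the GPU
        dims=2048)       # dimensionality of the pool_3 features
    print('FID:', fid)

Reading each batch from disk inside the loop trades extra I/O for a much smaller memory footprint, which is what makes the larger default batch size practical on large image sets.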