Skip to content

Commit 0184396

Browse files
author
Maximilian Seitzer
committed
Style corrections and remove torch.Variable
1 parent ede3c37 commit 0184396

File tree

1 file changed

+27
-38
lines changed

1 file changed

+27
-38
lines changed

fid_score.py

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@
3636
import pathlib
3737
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
3838

39-
import torch
4039
import numpy as np
41-
from scipy.misc import imread
40+
import torch
4241
from scipy import linalg
43-
from torch.autograd import Variable
42+
from scipy.misc import imread
4443
from torch.nn.functional import adaptive_avg_pool2d
45-
from tqdm import tqdm
44+
45+
try:
46+
from tqdm import tqdm
47+
except ImportError:
48+
# If not tqdm is not available, provide a mock version of it
49+
def tqdm(x): return x
4650

4751
from inception import InceptionV3
4852

@@ -65,8 +69,7 @@ def get_activations(files, model, batch_size=64, dims=2048,
6569
"""Calculates the activations of the pool_3 layer for all images.
6670
6771
Params:
68-
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
69-
must lie between 0 and 1.
72+
-- files : List of image files paths
7073
-- model : Instance of inception model
7174
-- batch_size : the images numpy array is split into batches with
7275
batch size batch_size. A reasonable batch size depends
@@ -81,34 +84,32 @@ def get_activations(files, model, batch_size=64, dims=2048,
8184
query tensor.
8285
"""
8386
model.eval()
84-
85-
#calculate number of total files
86-
d0 = len(files)
87-
if batch_size > d0:
87+
88+
if batch_size > len(files):
8889
print(('Warning: batch size is bigger than the data size. '
8990
'Setting batch size to data size'))
90-
batch_size = d0
91+
batch_size = len(files)
9192

92-
n_batches = d0 // batch_size
93+
n_batches = len(files) // batch_size
9394
n_used_imgs = n_batches * batch_size
9495

9596
pred_arr = np.empty((n_used_imgs, dims))
96-
97-
#Add processbar to know process
97+
9898
for i in tqdm(range(n_batches)):
9999
if verbose:
100100
print('\rPropagating batch %d/%d' % (i + 1, n_batches),
101101
end='', flush=True)
102102
start = i * batch_size
103103
end = start + batch_size
104-
105-
# real batch of images here
106-
images = np.array([imread(str(fn)).astype(np.float32) for fn in files[start:end]])
104+
105+
images = np.array([imread(str(f)).astype(np.float32)
106+
for f in files[start:end]])
107+
108+
# Reshape to (n_images, 3, height, width)
107109
images = images.transpose((0, 3, 1, 2))
108110
images /= 255
109-
111+
110112
batch = torch.from_numpy(images).type(torch.FloatTensor)
111-
batch = Variable(batch, volatile=True)
112113
if cuda:
113114
batch = batch.cuda()
114115

@@ -139,10 +140,10 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
139140
-- mu1 : Numpy array containing the activations of a layer of the
140141
inception net (like returned by the function 'get_predictions')
141142
for generated samples.
142-
-- mu2 : The sample mean over activations, precalculated on an
143+
-- mu2 : The sample mean over activations, precalculated on an
143144
representative data set.
144145
-- sigma1: The covariance matrix over activations for generated samples.
145-
-- sigma2: The covariance matrix over activations, precalculated on an
146+
-- sigma2: The covariance matrix over activations, precalculated on an
146147
representative data set.
147148
148149
Returns:
@@ -188,8 +189,7 @@ def calculate_activation_statistics(files, model, batch_size=64,
188189
dims=2048, cuda=False, verbose=False):
189190
"""Calculation of the statistics used by the FID.
190191
Params:
191-
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
192-
must lie between 0 and 1.
192+
-- files : List of image files paths
193193
-- model : Instance of inception model
194194
-- batch_size : The images numpy array is split into batches with
195195
batch size batch_size. A reasonable batch size
@@ -204,31 +204,20 @@ def calculate_activation_statistics(files, model, batch_size=64,
204204
-- sigma : The covariance matrix of the activations of the pool_3 layer of
205205
the inception model.
206206
"""
207-
# Instead of load all the images, we pass the file name list
208207
act = get_activations(files, model, batch_size, dims, cuda, verbose)
209208
mu = np.mean(act, axis=0)
210209
sigma = np.cov(act, rowvar=False)
211210
return mu, sigma
212211

213212

214-
def _compute_statistics_of_path(path, model, batch_size, dims, cuda, flag):
213+
def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
215214
if path.endswith('.npz'):
216215
f = np.load(path)
217216
m, s = f['mu'][:], f['sigma'][:]
218217
f.close()
219218
else:
220219
path = pathlib.Path(path)
221220
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
222-
223-
# imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
224-
225-
# Bring images to shape (B, 3, H, W)
226-
# imgs = imgs.transpose((0, 3, 1, 2))
227-
228-
# Rescale images to be between 0 and 1
229-
# imgs /= 255
230-
231-
# Instead of load all the images, we pass the file name list
232221
m, s = calculate_activation_statistics(files, model, batch_size,
233222
dims, cuda)
234223

@@ -246,11 +235,11 @@ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
246235
model = InceptionV3([block_idx])
247236
if cuda:
248237
model.cuda()
249-
238+
250239
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
251-
dims, cuda, 1)
240+
dims, cuda)
252241
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
253-
dims, cuda, 0)
242+
dims, cuda)
254243
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
255244

256245
return fid_value

0 commit comments

Comments
 (0)