Skip to content

Commit 8cab14a

Browse files
committed
read images in batch
1 parent 7475159 commit 8cab14a

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

__pycache__/inception.cpython-35.pyc

4.03 KB
Binary file not shown.

fid_score.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@
4242
from scipy import linalg
4343
from torch.autograd import Variable
4444
from torch.nn.functional import adaptive_avg_pool2d
45+
from tqdm import tqdm
4546

4647
from inception import InceptionV3
4748

48-
4949
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
5050
parser.add_argument('path', type=str, nargs=2,
5151
help=('Path to the generated images or '
5252
'to .npz statistic files'))
53-
parser.add_argument('--batch-size', type=int, default=64,
53+
parser.add_argument('--batch-size', type=int, default=256,
5454
help='Batch size to use')
5555
parser.add_argument('--dims', type=int, default=2048,
5656
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
@@ -60,7 +60,7 @@
6060
help='GPU to use (leave blank for CPU only)')
6161

6262

63-
def get_activations(images, model, batch_size=64, dims=2048,
63+
def get_activations(files, model, batch_size=64, dims=2048,
6464
cuda=False, verbose=False):
6565
"""Calculates the activations of the pool_3 layer for all images.
6666
@@ -81,8 +81,9 @@ def get_activations(images, model, batch_size=64, dims=2048,
8181
query tensor.
8282
"""
8383
model.eval()
84-
85-
d0 = images.shape[0]
84+
85+
# Total number of image files
86+
d0 = len(files)
8687
if batch_size > d0:
8788
print(('Warning: batch size is bigger than the data size. '
8889
'Setting batch size to data size'))
@@ -92,14 +93,21 @@ def get_activations(images, model, batch_size=64, dims=2048,
9293
n_used_imgs = n_batches * batch_size
9394

9495
pred_arr = np.empty((n_used_imgs, dims))
95-
for i in range(n_batches):
96+
97+
# Show a progress bar while iterating over the batches
98+
for i in tqdm(range(n_batches)):
9699
if verbose:
97100
print('\rPropagating batch %d/%d' % (i + 1, n_batches),
98101
end='', flush=True)
99102
start = i * batch_size
100103
end = start + batch_size
101-
102-
batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
104+
105+
# Read only the images of the current batch from disk
106+
images = np.array([imread(str(fn)).astype(np.float32) for fn in files[start:end]])
107+
images = images.transpose((0, 3, 1, 2))
108+
images /= 255
109+
110+
batch = torch.from_numpy(images).type(torch.FloatTensor)
103111
batch = Variable(batch, volatile=True)
104112
if cuda:
105113
batch = batch.cuda()
@@ -176,7 +184,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
176184
np.trace(sigma2) - 2 * tr_covmean)
177185

178186

179-
def calculate_activation_statistics(images, model, batch_size=64,
187+
def calculate_activation_statistics(files, model, batch_size=64,
180188
dims=2048, cuda=False, verbose=False):
181189
"""Calculation of the statistics used by the FID.
182190
Params:
@@ -196,13 +204,14 @@ def calculate_activation_statistics(images, model, batch_size=64,
196204
-- sigma : The covariance matrix of the activations of the pool_3 layer of
197205
the inception model.
198206
"""
199-
act = get_activations(images, model, batch_size, dims, cuda, verbose)
207+
# Instead of loading all the images, we pass the list of file names
208+
act = get_activations(files, model, batch_size, dims, cuda, verbose)
200209
mu = np.mean(act, axis=0)
201210
sigma = np.cov(act, rowvar=False)
202211
return mu, sigma
203212

204213

205-
def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
214+
def _compute_statistics_of_path(path, model, batch_size, dims, cuda, flag):
206215
if path.endswith('.npz'):
207216
f = np.load(path)
208217
m, s = f['mu'][:], f['sigma'][:]
@@ -211,15 +220,16 @@ def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
211220
path = pathlib.Path(path)
212221
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
213222

214-
imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
223+
# imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
215224

216225
# Bring images to shape (B, 3, H, W)
217-
imgs = imgs.transpose((0, 3, 1, 2))
226+
# imgs = imgs.transpose((0, 3, 1, 2))
218227

219228
# Rescale images to be between 0 and 1
220-
imgs /= 255
229+
# imgs /= 255
221230

222-
m, s = calculate_activation_statistics(imgs, model, batch_size,
231+
# Instead of loading all the images, we pass the list of file names
232+
m, s = calculate_activation_statistics(files, model, batch_size,
223233
dims, cuda)
224234

225235
return m, s
@@ -236,11 +246,11 @@ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
236246
model = InceptionV3([block_idx])
237247
if cuda:
238248
model.cuda()
239-
249+
240250
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
241-
dims, cuda)
251+
dims, cuda, 1)
242252
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
243-
dims, cuda)
253+
dims, cuda, 0)
244254
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
245255

246256
return fid_value

0 commit comments

Comments
 (0)