
Commit 230d16d

Maximilian Seitzer committed
Add fid score
1 parent c1c6c74 commit 230d16d

File tree

2 files changed: +398 -0 lines changed


fid_score.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
#!/usr/bin/env python3
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.

Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow

Copyright 2018 Institute of Bioinformatics, JKU Linz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

import torch
import numpy as np
from scipy.misc import imread
from scipy import linalg
from torch.autograd import Variable
from torch.nn.functional import adaptive_avg_pool2d

from inception import InceptionV3


parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('path', type=str, nargs=2,
                    help=('Path to the generated images or '
                          'to .npz statistic files'))
parser.add_argument('--batch-size', type=int, default=64,
                    help='Batch size to use')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
parser.add_argument('-c', '--gpu', default='', type=str,
                    help='GPU to use (leave blank for CPU only)')


def get_activations(images, model, batch_size=64, dims=2048,
                    cuda=False, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- images     : Numpy array of dimension (n_images, 3, hi, wi). The values
                    must lie between 0 and 1.
    -- model      : Instance of inception model
    -- batch_size : The images numpy array is split into batches with
                    batch size batch_size. A reasonable batch size depends
                    on the hardware.
    -- dims       : Dimensionality of features returned by Inception
    -- cuda       : If set to True, use GPU
    -- verbose    : If set to True, the number of calculated batches is
                    reported.
    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    d0 = images.shape[0]
    if batch_size > d0:
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = d0

    n_batches = d0 // batch_size
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))
    for i in range(n_batches):
        if verbose:
            print('\rPropagating batch %d/%d' % (i + 1, n_batches),
                  end='', flush=True)
        start = i * batch_size
        end = start + batch_size

        batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
        batch = Variable(batch, volatile=True)
        if cuda:
            batch = batch.cuda()

        pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)

    if verbose:
        print(' done')

    return pred_arr

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Numpy array containing the sample mean over activations of the
               inception net (like returned by the function 'get_activations')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

def calculate_activation_statistics(images, model, batch_size=64,
                                    dims=2048, cuda=False, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- images     : Numpy array of dimension (n_images, 3, hi, wi). The values
                    must lie between 0 and 1.
    -- model      : Instance of inception model
    -- batch_size : The images numpy array is split into batches with
                    batch size batch_size. A reasonable batch size
                    depends on the hardware.
    -- dims       : Dimensionality of features returned by Inception
    -- cuda       : If set to True, use GPU
    -- verbose    : If set to True, the number of calculated batches is
                    reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(images, model, batch_size, dims, cuda, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
    if path.endswith('.npz'):
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
        f.close()
    else:
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))

        imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])

        # Bring images to shape (B, 3, H, W)
        imgs = imgs.transpose((0, 3, 1, 2))

        # Rescale images to be between 0 and 1
        imgs /= 255

        m, s = calculate_activation_statistics(imgs, model, batch_size,
                                               dims, cuda)

    return m, s


def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    """Calculates the FID of two paths"""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx])
    if cuda:
        model.cuda()

    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
                                         dims, cuda)
    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
                                         dims, cuda)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value


if __name__ == '__main__':
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    fid_value = calculate_fid_given_paths(args.path,
                                          args.batch_size,
                                          args.gpu != '',
                                          args.dims)
    print('FID: ', fid_value)
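
For context, a minimal usage sketch (not part of this commit): the script is normally run from the command line with two paths, each either a directory of PNG/JPEG images or a precomputed .npz statistics file, but the same computation can be driven programmatically through calculate_fid_given_paths. The directory names below are hypothetical placeholders.

# Minimal sketch: programmatic use of fid_score (paths are hypothetical).
# Roughly equivalent to: python fid_score.py path/to/real_images path/to/fake_images -c 0
from fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(['path/to/real_images', 'path/to/fake_images'],
                                batch_size=64,  # same default as --batch-size
                                cuda=True,      # False for CPU-only runs
                                dims=2048)      # pool3 features (default --dims)
print('FID:', fid)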

inception.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, normalizes the input to the statistics the pretrained
            Inception network expects
        requires_grad : bool
            If true, parameters of the model require gradient. Possibly useful
            for finetuning the network
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        blocks, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.upsample(x, size=(299, 299), mode='bilinear')

        if self.normalize_input:
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp
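
As a reference for how this wrapper is consumed by fid_score.py, here is a minimal sketch (not part of the commit) that extracts 2048-dimensional pool3 features; the random tensor merely stands in for a batch of real images scaled to (0, 1).

# Minimal sketch: pulling pool3 features with InceptionV3 (illustrative only).
import torch
from torch.autograd import Variable
from inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # block 3: final average pooling
model = InceptionV3([block_idx])
model.eval()

batch = Variable(torch.rand(4, 3, 299, 299), volatile=True)  # B x 3 x H x W, values in (0, 1)
features = model(batch)[0]        # forward() returns one entry per requested block
features = features.view(4, -1)   # -> shape (4, 2048)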

0 commit comments
