Skip to content

Commit 40e7e26

Browse files
author
Maximilian Seitzer
committed
Set batch size to 50 to match TF FID score implementation
1 parent 0184396 commit 40e7e26

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

fid_score.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def tqdm(x): return x
5454
parser.add_argument('path', type=str, nargs=2,
5555
help=('Path to the generated images or '
5656
'to .npz statistic files'))
57-
parser.add_argument('--batch-size', type=int, default=256,
57+
parser.add_argument('--batch-size', type=int, default=50,
5858
help='Batch size to use')
5959
parser.add_argument('--dims', type=int, default=2048,
6060
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
@@ -64,16 +64,18 @@ def tqdm(x): return x
6464
help='GPU to use (leave blank for CPU only)')
6565

6666

67-
def get_activations(files, model, batch_size=64, dims=2048,
67+
def get_activations(files, model, batch_size=50, dims=2048,
6868
cuda=False, verbose=False):
6969
"""Calculates the activations of the pool_3 layer for all images.
7070
7171
Params:
7272
-- files : List of image files paths
7373
-- model : Instance of inception model
74-
-- batch_size : the images numpy array is split into batches with
75-
batch size batch_size. A reasonable batch size depends
76-
on the hardware.
74+
-- batch_size : Batch size of images for the model to process at once.
75+
Make sure that the number of samples is a multiple of
76+
the batch size, otherwise some samples are ignored. This
77+
behavior is retained to match the original FID score
78+
implementation.
7779
-- dims : Dimensionality of features returned by Inception
7880
-- cuda : If set to True, use GPU
7981
-- verbose : If set to True and parameter out_step is given, the number
@@ -85,6 +87,9 @@ def get_activations(files, model, batch_size=64, dims=2048,
8587
"""
8688
model.eval()
8789

90+
if len(files) % batch_size != 0:
91+
print(('Warning: number of images is not a multiple of the '
92+
'batch size. Some samples are going to be ignored.'))
8893
if batch_size > len(files):
8994
print(('Warning: batch size is bigger than the data size. '
9095
'Setting batch size to data size'))
@@ -185,7 +190,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
185190
np.trace(sigma2) - 2 * tr_covmean)
186191

187192

188-
def calculate_activation_statistics(files, model, batch_size=64,
193+
def calculate_activation_statistics(files, model, batch_size=50,
189194
dims=2048, cuda=False, verbose=False):
190195
"""Calculation of the statistics used by the FID.
191196
Params:

0 commit comments

Comments
 (0)