
Commit edb425e

Add crop_pct arg to validate, extra fields to csv output, 'all' filters pretrained
1 parent 949b7a8
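As a rough usage sketch of the two user-facing changes (the dataset path and model name are illustrative, not part of this commit):

    python validate.py /path/to/imagenet --model seresnet50 --crop-pct 0.9
    python validate.py /path/to/imagenet --model all

The first invocation overrides the model's default center crop fraction; the second validates only models that ship pretrained weights, writing one CSV row per model.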


validate.py

Lines changed: 15 additions & 8 deletions
@@ -30,6 +30,8 @@
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop pct')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -81,6 +83,7 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
         Dataset(args.data, load_bytes=args.tf_preprocessing),
         input_size=data_config['input_size'],
@@ -90,7 +93,7 @@ def validate(args):
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
-        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
+        crop_pct=crop_pct,
         tf_preprocessing=args.tf_preprocessing)
 
     batch_time = AverageMeter()
@@ -124,16 +127,19 @@ def validate(args):
                 'Test: [{0:>4d}/{1}] '
                 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
+                'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                     i, len(loader), batch_time=batch_time,
                     rate_avg=input.size(0) / batch_time.avg,
                     loss=losses, top1=top1, top5=top5))
 
     results = OrderedDict(
-        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
-        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
-        param_count=round(param_count / 1e6, 2))
+        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
+        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
+        param_count=round(param_count / 1e6, 2),
+        img_size=data_config['input_size'][-1],
+        cropt_pct=crop_pct,
+        interpolation=data_config['interpolation'])
 
     logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@@ -155,7 +161,7 @@ def main():
     if args.model == 'all':
         # validate all models in a list of names with pretrained checkpoints
         args.pretrained = True
-        model_names = list_models()
+        model_names = list_models(pretrained=True)
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter
@@ -170,7 +176,8 @@ def main():
             args.model = m
             args.checkpoint = c
             result = OrderedDict(model=args.model)
-            result.update(validate(args))
+            r = validate(args)
+            result.update(r)
             if args.checkpoint:
                 result['checkpoint'] = args.checkpoint
             dw = csv.DictWriter(cf, fieldnames=result.keys())
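For the 'all' path, list_models(pretrained=True) restricts the sweep to architectures that actually have downloadable weights, instead of failing on registered models without checkpoints. A minimal sketch of the same filtering against the packaged library (assuming it exposes the helper under this import path; the print is only for illustration):

    from timm.models import list_models

    # Every registered architecture vs. only those with pretrained weights;
    # the --model all mode in validate.py now iterates the filtered list.
    every_model = list_models()
    with_weights = list_models(pretrained=True)
    print(len(every_model), len(with_weights))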

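With the extra fields, each CSV row now records the eval image size, crop pct, and interpolation alongside the accuracy numbers (note the key is spelled cropt_pct in this commit). A minimal sketch of reading such a file back, assuming a hypothetical results.csv produced by a --model all run:

    import csv

    # 'results.csv' is a placeholder name; this commit defines the row
    # fields written via csv.DictWriter, not the output filename.
    with open('results.csv') as cf:
        for row in csv.DictReader(cf):
            # Fields added by this commit: img_size, cropt_pct, interpolation;
            # pre-existing: top1/top5 (and *_err), param_count.
            print(row['model'], row['top1'], row['img_size'],
                  row['cropt_pct'], row['interpolation'])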