                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop pct')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -81,6 +83,7 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
         Dataset(args.data, load_bytes=args.tf_preprocessing),
         input_size=data_config['input_size'],
@@ -90,7 +93,7 @@ def validate(args):
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
-        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
+        crop_pct=crop_pct,
         tf_preprocessing=args.tf_preprocessing)
 
     batch_time = AverageMeter()
@@ -124,16 +127,19 @@ def validate(args):
                 'Test: [{0:>4d}/{1}] '
                 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
+                'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                     i, len(loader), batch_time=batch_time,
                     rate_avg=input.size(0) / batch_time.avg,
                     loss=losses, top1=top1, top5=top5))
 
     results = OrderedDict(
-        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
-        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
-        param_count=round(param_count / 1e6, 2))
+        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
+        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
+        param_count=round(param_count / 1e6, 2),
+        img_size=data_config['input_size'][-1],
+        crop_pct=crop_pct,
+        interpolation=data_config['interpolation'])
 
     logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@@ -155,7 +161,7 @@ def main():
     if args.model == 'all':
         # validate all models in a list of names with pretrained checkpoints
         args.pretrained = True
-        model_names = list_models()
+        model_names = list_models(pretrained=True)
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter
@@ -170,7 +176,8 @@ def main():
                 args.model = m
                 args.checkpoint = c
                 result = OrderedDict(model=args.model)
-                result.update(validate(args))
+                r = validate(args)
+                result.update(r)
                 if args.checkpoint:
                     result['checkpoint'] = args.checkpoint
                 dw = csv.DictWriter(cf, fieldnames=result.keys())
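
Note: a minimal standalone sketch of the pattern the diff above introduces, i.e. computing crop_pct once and echoing it back in the results row alongside the metrics. The data_config contents, the test_time_pool flag, and the accuracy numbers below are illustrative assumptions, not values from this commit.

# Standalone sketch (illustrative values only):
from collections import OrderedDict

# Stand-ins for what resolve_data_config() and the model setup would provide.
data_config = {'input_size': (3, 224, 224), 'crop_pct': 0.875, 'interpolation': 'bicubic'}
test_time_pool = False

# Compute crop_pct once; the loader and the results row share the same value.
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

# ... validation loop would run here; dummy averages for illustration ...
top1_avg, top5_avg = 76.13, 92.86

results = OrderedDict(
    top1=round(top1_avg, 4), top1_err=round(100 - top1_avg, 4),
    top5=round(top5_avg, 4), top5_err=round(100 - top5_avg, 4),
    img_size=data_config['input_size'][-1],
    crop_pct=crop_pct,
    interpolation=data_config['interpolation'])

print(results)  # each CSV row now records the eval transform it was produced with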