50
50
help = 'Number of GPUS to use' )
51
51
parser .add_argument ('--no-test-pool' , dest = 'no_test_pool' , action = 'store_true' ,
52
52
help = 'disable test time pool' )
53
- parser .add_argument ('--tf-preprocessing' , dest = 'tf_preprocessing' , action = 'store_true' ,
53
+ parser .add_argument ('--no-prefetcher' , action = 'store_true' , default = False ,
54
+ help = 'disable fast prefetcher' )
55
+ parser .add_argument ('--fp16' , action = 'store_true' , default = False ,
56
+ help = 'Use half precision (fp16)' )
57
+ parser .add_argument ('--tf-preprocessing' , action = 'store_true' , default = False ,
54
58
help = 'Use Tensorflow preprocessing pipeline (require CPU TF installed' )
55
59
parser .add_argument ('--use-ema' , dest = 'use_ema' , action = 'store_true' ,
56
60
help = 'use ema version of weights if present' )
59
63
def validate (args ):
60
64
# might as well try to validate something
61
65
args .pretrained = args .pretrained or not args .checkpoint
66
+ args .prefetcher = not args .no_prefetcher
62
67
63
68
# create model
64
69
model = create_model (
@@ -81,19 +86,23 @@ def validate(args):
81
86
else :
82
87
model = model .cuda ()
83
88
89
+ if args .fp16 :
90
+ model = model .half ()
91
+
84
92
criterion = nn .CrossEntropyLoss ().cuda ()
85
93
86
94
crop_pct = 1.0 if test_time_pool else data_config ['crop_pct' ]
87
95
loader = create_loader (
88
96
Dataset (args .data , load_bytes = args .tf_preprocessing ),
89
97
input_size = data_config ['input_size' ],
90
98
batch_size = args .batch_size ,
91
- use_prefetcher = True ,
99
+ use_prefetcher = args . prefetcher ,
92
100
interpolation = data_config ['interpolation' ],
93
101
mean = data_config ['mean' ],
94
102
std = data_config ['std' ],
95
103
num_workers = args .workers ,
96
104
crop_pct = crop_pct ,
105
+ fp16 = args .fp16 ,
97
106
tf_preprocessing = args .tf_preprocessing )
98
107
99
108
batch_time = AverageMeter ()
@@ -105,8 +114,11 @@ def validate(args):
105
114
end = time .time ()
106
115
with torch .no_grad ():
107
116
for i , (input , target ) in enumerate (loader ):
108
- target = target .cuda ()
109
- input = input .cuda ()
117
+ if args .no_prefetcher :
118
+ target = target .cuda ()
119
+ input = input .cuda ()
120
+ if args .fp16 :
121
+ input = input .half ()
110
122
111
123
# compute output
112
124
output = model (input )
@@ -125,7 +137,7 @@ def validate(args):
125
137
if i % args .log_freq == 0 :
126
138
logging .info (
127
139
'Test: [{0:>4d}/{1}] '
128
- 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
140
+ 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s ) '
129
141
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
130
142
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
131
143
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})' .format (
0 commit comments