Skip to content

Commit 6cdf35e

Browse files
committed
Add explicit half/fp16 support to loader and validation script
1 parent 5684c6a commit 6cdf35e

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

timm/data/loader.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,15 @@ def __init__(self,
2121
rand_erase_prob=0.,
2222
rand_erase_mode='const',
2323
mean=IMAGENET_DEFAULT_MEAN,
24-
std=IMAGENET_DEFAULT_STD):
24+
std=IMAGENET_DEFAULT_STD,
25+
fp16=False):
2526
self.loader = loader
2627
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
2728
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
29+
self.fp16 = fp16
30+
if fp16:
31+
self.mean = self.mean.half()
32+
self.std = self.std.half()
2833
if rand_erase_prob > 0.:
2934
self.random_erasing = RandomErasing(
3035
probability=rand_erase_prob, mode=rand_erase_mode)
@@ -39,7 +44,10 @@ def __iter__(self):
3944
with torch.cuda.stream(stream):
4045
next_input = next_input.cuda(non_blocking=True)
4146
next_target = next_target.cuda(non_blocking=True)
42-
next_input = next_input.float().sub_(self.mean).div_(self.std)
47+
if self.fp16:
48+
next_input = next_input.half().sub_(self.mean).div_(self.std)
49+
else:
50+
next_input = next_input.float().sub_(self.mean).div_(self.std)
4351
if self.random_erasing is not None:
4452
next_input = self.random_erasing(next_input)
4553

@@ -94,6 +102,7 @@ def create_loader(
94102
distributed=False,
95103
crop_pct=None,
96104
collate_fn=None,
105+
fp16=False,
97106
tf_preprocessing=False,
98107
):
99108
if isinstance(input_size, tuple):
@@ -151,6 +160,7 @@ def create_loader(
151160
rand_erase_prob=rand_erase_prob if is_training else 0.,
152161
rand_erase_mode=rand_erase_mode,
153162
mean=mean,
154-
std=std)
163+
std=std,
164+
fp16=fp16)
155165

156166
return loader

timm/utils.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -156,12 +156,7 @@ def accuracy(output, target, topk=(1,)):
156156
_, pred = output.topk(maxk, 1, True, True)
157157
pred = pred.t()
158158
correct = pred.eq(target.view(1, -1).expand_as(pred))
159-
160-
res = []
161-
for k in topk:
162-
correct_k = correct[:k].view(-1).float().sum(0)
163-
res.append(correct_k.mul_(100.0 / batch_size))
164-
return res
159+
return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk]
165160

166161

167162
def get_outdir(path, *paths, inc=False):

timm/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = '0.1.7'
1+
__version__ = '0.1.8'

validate.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,11 @@
5050
help='Number of GPUS to use')
5151
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
5252
help='disable test time pool')
53-
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
53+
parser.add_argument('--no-prefetcher', action='store_true', default=False,
54+
help='disable fast prefetcher')
55+
parser.add_argument('--fp16', action='store_true', default=False,
56+
help='Use half precision (fp16)')
57+
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
5458
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
5559
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
5660
help='use ema version of weights if present')
@@ -59,6 +63,7 @@
5963
def validate(args):
6064
# might as well try to validate something
6165
args.pretrained = args.pretrained or not args.checkpoint
66+
args.prefetcher = not args.no_prefetcher
6267

6368
# create model
6469
model = create_model(
@@ -81,19 +86,23 @@ def validate(args):
8186
else:
8287
model = model.cuda()
8388

89+
if args.fp16:
90+
model = model.half()
91+
8492
criterion = nn.CrossEntropyLoss().cuda()
8593

8694
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
8795
loader = create_loader(
8896
Dataset(args.data, load_bytes=args.tf_preprocessing),
8997
input_size=data_config['input_size'],
9098
batch_size=args.batch_size,
91-
use_prefetcher=True,
99+
use_prefetcher=args.prefetcher,
92100
interpolation=data_config['interpolation'],
93101
mean=data_config['mean'],
94102
std=data_config['std'],
95103
num_workers=args.workers,
96104
crop_pct=crop_pct,
105+
fp16=args.fp16,
97106
tf_preprocessing=args.tf_preprocessing)
98107

99108
batch_time = AverageMeter()
@@ -105,8 +114,11 @@ def validate(args):
105114
end = time.time()
106115
with torch.no_grad():
107116
for i, (input, target) in enumerate(loader):
108-
target = target.cuda()
109-
input = input.cuda()
117+
if args.no_prefetcher:
118+
target = target.cuda()
119+
input = input.cuda()
120+
if args.fp16:
121+
input = input.half()
110122

111123
# compute output
112124
output = model(input)
@@ -125,7 +137,7 @@ def validate(args):
125137
if i % args.log_freq == 0:
126138
logging.info(
127139
'Test: [{0:>4d}/{1}] '
128-
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
140+
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
129141
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
130142
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
131143
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(

0 commit comments

Comments
 (0)