Skip to content

Commit 53c4747

Browse files
committed
Batch validation batch size adjustment, tweak L2 crop pct
1 parent 08553e1 commit 53c4747

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _cfg(url='', **kwargs):
194194
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
195195
'tf_efficientnet_l2_ns': _cfg(
196196
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
197-
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
197+
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
198198
'tf_efficientnet_es': _cfg(
199199
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
200200
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),

validate.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,24 @@ def main():
211211
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
212212
results = []
213213
try:
214+
start_batch_size = args.batch_size
214215
for m, c in model_cfgs:
216+
batch_size = start_batch_size
215217
args.model = m
216218
args.checkpoint = c
217219
result = OrderedDict(model=args.model)
218-
r = validate(args)
220+
r = {}
221+
while not r and batch_size >= args.num_gpu:
222+
try:
223+
args.batch_size = batch_size
224+
print('Validating with batch size: %d' % args.batch_size)
225+
r = validate(args)
226+
except RuntimeError as e:
227+
if batch_size <= args.num_gpu:
228+
print("Validation failed with no ability to reduce batch size. Exiting.")
229+
raise e
230+
batch_size = max(batch_size // 2, args.num_gpu)
231+
print("Validation failed, reducing batch size by 50%")
219232
result.update(r)
220233
if args.checkpoint:
221234
result['checkpoint'] = args.checkpoint

0 commit comments

Comments
 (0)