Skip to content

Commit dfaab97

Browse files
committed
More consistency in model arg/kwarg merge handling
1 parent 3370053 commit dfaab97

File tree

9 files changed

+154
-163
lines changed

9 files changed

+154
-163
lines changed

timm/models/davit.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -648,41 +648,35 @@ def _cfg(url='', **kwargs):
648648

649649
@register_model
650650
def davit_tiny(pretrained=False, **kwargs) -> DaVit:
651-
model_kwargs = dict(
652-
depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
653-
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
651+
model_args = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
652+
return _create_davit('davit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
654653

655654

656655
@register_model
657656
def davit_small(pretrained=False, **kwargs) -> DaVit:
658-
model_kwargs = dict(
659-
depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
660-
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
657+
model_args = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
658+
return _create_davit('davit_small', pretrained=pretrained, **dict(model_args, **kwargs))
661659

662660

663661
@register_model
664662
def davit_base(pretrained=False, **kwargs) -> DaVit:
665-
model_kwargs = dict(
666-
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
667-
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
663+
model_args = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32))
664+
return _create_davit('davit_base', pretrained=pretrained, **dict(model_args, **kwargs))
668665

669666

670667
@register_model
671668
def davit_large(pretrained=False, **kwargs) -> DaVit:
672-
model_kwargs = dict(
673-
depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
674-
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
669+
model_args = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48))
670+
return _create_davit('davit_large', pretrained=pretrained, **dict(model_args, **kwargs))
675671

676672

677673
@register_model
678674
def davit_huge(pretrained=False, **kwargs) -> DaVit:
679-
model_kwargs = dict(
680-
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
681-
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
675+
model_args = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64))
676+
return _create_davit('davit_huge', pretrained=pretrained, **dict(model_args, **kwargs))
682677

683678

684679
@register_model
685680
def davit_giant(pretrained=False, **kwargs) -> DaVit:
686-
model_kwargs = dict(
687-
depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
688-
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
681+
model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
682+
return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))

timm/models/densenet.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,8 @@ def densenet121(pretrained=False, **kwargs) -> DenseNet:
361361
r"""Densenet-121 model from
362362
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
363363
"""
364-
model = _create_densenet(
365-
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
364+
model_args = dict(growth_rate=32, block_config=(6, 12, 24, 16))
365+
model = _create_densenet('densenet121', pretrained=pretrained, **dict(model_args, **kwargs))
366366
return model
367367

368368

@@ -371,9 +371,8 @@ def densenetblur121d(pretrained=False, **kwargs) -> DenseNet:
371371
r"""Densenet-121 w/ blur-pooling & 3-layer 3x3 stem
372372
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
373373
"""
374-
model = _create_densenet(
375-
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained,
376-
stem_type='deep', aa_layer=BlurPool2d, **kwargs)
374+
model_args = dict(growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', aa_layer=BlurPool2d)
375+
model = _create_densenet('densenetblur121d', pretrained=pretrained, **dict(model_args, **kwargs))
377376
return model
378377

379378

@@ -382,8 +381,8 @@ def densenet169(pretrained=False, **kwargs) -> DenseNet:
382381
r"""Densenet-169 model from
383382
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
384383
"""
385-
model = _create_densenet(
386-
'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs)
384+
model_args = dict(growth_rate=32, block_config=(6, 12, 32, 32))
385+
model = _create_densenet('densenet169', pretrained=pretrained, **dict(model_args, **kwargs))
387386
return model
388387

389388

@@ -392,8 +391,8 @@ def densenet201(pretrained=False, **kwargs) -> DenseNet:
392391
r"""Densenet-201 model from
393392
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
394393
"""
395-
model = _create_densenet(
396-
'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs)
394+
model_args = dict(growth_rate=32, block_config=(6, 12, 48, 32))
395+
model = _create_densenet('densenet201', pretrained=pretrained, **dict(model_args, **kwargs))
397396
return model
398397

399398

@@ -402,8 +401,8 @@ def densenet161(pretrained=False, **kwargs) -> DenseNet:
402401
r"""Densenet-161 model from
403402
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
404403
"""
405-
model = _create_densenet(
406-
'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs)
404+
model_args = dict(growth_rate=48, block_config=(6, 12, 36, 24))
405+
model = _create_densenet('densenet161', pretrained=pretrained, **dict(model_args, **kwargs))
407406
return model
408407

409408

@@ -412,7 +411,7 @@ def densenet264d(pretrained=False, **kwargs) -> DenseNet:
412411
r"""Densenet-264 model from
413412
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
414413
"""
415-
model = _create_densenet(
416-
'densenet264d', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', pretrained=pretrained, **kwargs)
414+
model_args = dict(growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep')
415+
model = _create_densenet('densenet264d', pretrained=pretrained, **dict(model_args, **kwargs))
417416
return model
418417

timm/models/dpn.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -317,55 +317,55 @@ def _cfg(url='', **kwargs):
317317

318318
@register_model
319319
def dpn48b(pretrained=False, **kwargs) -> DPN:
320-
model_kwargs = dict(
320+
model_args = dict(
321321
small=True, num_init_features=10, k_r=128, groups=32,
322322
b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
323-
return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
323+
return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_args, **kwargs))
324324

325325

326326
@register_model
327327
def dpn68(pretrained=False, **kwargs) -> DPN:
328-
model_kwargs = dict(
328+
model_args = dict(
329329
small=True, num_init_features=10, k_r=128, groups=32,
330330
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
331-
return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs))
331+
return _create_dpn('dpn68', pretrained=pretrained, **dict(model_args, **kwargs))
332332

333333

334334
@register_model
335335
def dpn68b(pretrained=False, **kwargs) -> DPN:
336-
model_kwargs = dict(
336+
model_args = dict(
337337
small=True, num_init_features=10, k_r=128, groups=32,
338338
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
339-
return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
339+
return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_args, **kwargs))
340340

341341

342342
@register_model
343343
def dpn92(pretrained=False, **kwargs) -> DPN:
344-
model_kwargs = dict(
344+
model_args = dict(
345345
num_init_features=64, k_r=96, groups=32,
346346
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
347-
return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs))
347+
return _create_dpn('dpn92', pretrained=pretrained, **dict(model_args, **kwargs))
348348

349349

350350
@register_model
351351
def dpn98(pretrained=False, **kwargs) -> DPN:
352-
model_kwargs = dict(
352+
model_args = dict(
353353
num_init_features=96, k_r=160, groups=40,
354354
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
355-
return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs))
355+
return _create_dpn('dpn98', pretrained=pretrained, **dict(model_args, **kwargs))
356356

357357

358358
@register_model
359359
def dpn131(pretrained=False, **kwargs) -> DPN:
360-
model_kwargs = dict(
360+
model_args = dict(
361361
num_init_features=128, k_r=160, groups=40,
362362
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
363-
return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs))
363+
return _create_dpn('dpn131', pretrained=pretrained, **dict(model_args, **kwargs))
364364

365365

366366
@register_model
367367
def dpn107(pretrained=False, **kwargs) -> DPN:
368-
model_kwargs = dict(
368+
model_args = dict(
369369
num_init_features=128, k_r=200, groups=50,
370370
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
371-
return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs))
371+
return _create_dpn('dpn107', pretrained=pretrained, **dict(model_args, **kwargs))

timm/models/edgenext.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,8 @@ def edgenext_xx_small(pretrained=False, **kwargs) -> EdgeNeXt:
528528
# No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
529529
# Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
530530
# For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
531-
model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs)
532-
return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs)
531+
model_args = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4))
532+
return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **dict(model_args, **kwargs))
533533

534534

535535
@register_model
@@ -539,8 +539,8 @@ def edgenext_x_small(pretrained=False, **kwargs) -> EdgeNeXt:
539539
# No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
540540
# Jetson FPS=31.61 versus 28.49 for MobileViT_XS
541541
# For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
542-
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs)
543-
return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs)
542+
model_args = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4))
543+
return _create_edgenext('edgenext_x_small', pretrained=pretrained, **dict(model_args, **kwargs))
544544

545545

546546
@register_model
@@ -550,8 +550,8 @@ def edgenext_small(pretrained=False, **kwargs) -> EdgeNeXt:
550550
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
551551
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
552552
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
553-
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs)
554-
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
553+
model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304))
554+
return _create_edgenext('edgenext_small', pretrained=pretrained, **dict(model_args, **kwargs))
555555

556556

557557
@register_model
@@ -561,14 +561,14 @@ def edgenext_base(pretrained=False, **kwargs) -> EdgeNeXt:
561561
# AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
562562
# Jetson FPS=xx.xx versus xx.xx for MobileViT_S
563563
# For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
564-
model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs)
565-
return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs)
564+
model_args = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584])
565+
return _create_edgenext('edgenext_base', pretrained=pretrained, **dict(model_args, **kwargs))
566566

567567

568568
@register_model
569569
def edgenext_small_rw(pretrained=False, **kwargs) -> EdgeNeXt:
570-
model_kwargs = dict(
570+
model_args = dict(
571571
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
572-
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
573-
return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs)
572+
downsample_block=True, conv_bias=False, stem_type='overlap')
573+
return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
574574

0 commit comments

Comments
 (0)