Skip to content

Commit 1c13ef7

Browse files
committed
Add default norm_eps=1e-5 for convnext_xxlarge, improve kwarg merging for all convnext models
1 parent 450b74a commit 1c13ef7

File tree

1 file changed

+51
-61
lines changed

1 file changed

+51
-61
lines changed

timm/models/convnext.py

Lines changed: 51 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -789,200 +789,190 @@ def _cfgv2(url='', **kwargs):
789789
@register_model
790790
def convnext_atto(pretrained=False, **kwargs):
791791
# timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
792-
model_args = dict(
793-
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
794-
model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
792+
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
793+
model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
795794
return model
796795

797796

798797
@register_model
799798
def convnext_atto_ols(pretrained=False, **kwargs):
800799
# timm atto variant with overlapping 3x3 conv stem, same dims as the non-ols atto above, current param count 3.7M
801-
model_args = dict(
802-
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
803-
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
800+
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
801+
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
804802
return model
805803

806804

807805
@register_model
808806
def convnext_femto(pretrained=False, **kwargs):
809807
# timm femto variant
810-
model_args = dict(
811-
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
812-
model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
808+
model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
809+
model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
813810
return model
814811

815812

816813
@register_model
817814
def convnext_femto_ols(pretrained=False, **kwargs):
818815
# timm femto variant with overlapping conv stem (stem_type='overlap_tiered')
819-
model_args = dict(
820-
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
821-
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
816+
model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
817+
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
822818
return model
823819

824820

825821
@register_model
826822
def convnext_pico(pretrained=False, **kwargs):
827823
# timm pico variant
828-
model_args = dict(
829-
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
830-
model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
824+
model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
825+
model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
831826
return model
832827

833828

834829
@register_model
835830
def convnext_pico_ols(pretrained=False, **kwargs):
836831
# timm pico variant with overlapping 3x3 conv stem
837-
model_args = dict(
838-
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
839-
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
832+
model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
833+
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
840834
return model
841835

842836

843837
@register_model
844838
def convnext_nano(pretrained=False, **kwargs):
845839
# timm nano variant with standard stem and head
846-
model_args = dict(
847-
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
848-
model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
840+
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
841+
model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
849842
return model
850843

851844

852845
@register_model
853846
def convnext_nano_ols(pretrained=False, **kwargs):
854847
# experimental nano variant with overlapping conv stem
855-
model_args = dict(
856-
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
857-
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
848+
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
849+
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
858850
return model
859851

860852

861853
@register_model
862854
def convnext_tiny_hnf(pretrained=False, **kwargs):
863855
# experimental tiny variant with norm before pooling in head (head norm first)
864-
model_args = dict(
865-
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
866-
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
856+
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
857+
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
867858
return model
868859

869860

870861
@register_model
871862
def convnext_tiny(pretrained=False, **kwargs):
872-
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
873-
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
863+
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
864+
model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
874865
return model
875866

876867

877868
@register_model
878869
def convnext_small(pretrained=False, **kwargs):
879-
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
880-
model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
870+
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
871+
model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
881872
return model
882873

883874

884875
@register_model
885876
def convnext_base(pretrained=False, **kwargs):
886-
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
887-
model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
877+
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
878+
model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
888879
return model
889880

890881

891882
@register_model
892883
def convnext_large(pretrained=False, **kwargs):
893-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
894-
model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
884+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
885+
model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
895886
return model
896887

897888

898889
@register_model
899890
def convnext_large_mlp(pretrained=False, **kwargs):
900-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
901-
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
891+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
892+
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
902893
return model
903894

904895

905896
@register_model
906897
def convnext_xlarge(pretrained=False, **kwargs):
907-
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
908-
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
898+
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
899+
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
909900
return model
910901

911902

912903
@register_model
913904
def convnext_xxlarge(pretrained=False, **kwargs):
914-
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
915-
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
905+
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
906+
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
916907
return model
917908

918909

919910
@register_model
920911
def convnextv2_atto(pretrained=False, **kwargs):
921912
# timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
922913
model_args = dict(
923-
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
924-
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args)
914+
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
915+
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
925916
return model
926917

927918

928919
@register_model
929920
def convnextv2_femto(pretrained=False, **kwargs):
930921
# timm femto variant
931922
model_args = dict(
932-
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
933-
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args)
923+
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
924+
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
934925
return model
935926

936927

937928
@register_model
938929
def convnextv2_pico(pretrained=False, **kwargs):
939930
# timm pico variant
940931
model_args = dict(
941-
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
942-
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args)
932+
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
933+
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
943934
return model
944935

945936

946937
@register_model
947938
def convnextv2_nano(pretrained=False, **kwargs):
948939
# timm nano variant with standard stem and head
949940
model_args = dict(
950-
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
951-
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args)
941+
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
942+
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
952943
return model
953944

954945

955946
@register_model
956947
def convnextv2_tiny(pretrained=False, **kwargs):
957-
model_args = dict(
958-
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs)
959-
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
948+
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
949+
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
960950
return model
961951

962952

963953
@register_model
964954
def convnextv2_small(pretrained=False, **kwargs):
965-
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs)
966-
model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args)
955+
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
956+
model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
967957
return model
968958

969959

970960
@register_model
971961
def convnextv2_base(pretrained=False, **kwargs):
972-
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs)
973-
model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args)
962+
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
963+
model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
974964
return model
975965

976966

977967
@register_model
978968
def convnextv2_large(pretrained=False, **kwargs):
979-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs)
980-
model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args)
969+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
970+
model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
981971
return model
982972

983973

984974
@register_model
985975
def convnextv2_huge(pretrained=False, **kwargs):
986-
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs)
987-
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args)
988-
return model
976+
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
977+
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
978+
return model

0 commit comments

Comments
 (0)