Skip to content

Commit 77e2e0c

Browse files
committed
Add new auto-augmentation Tensorflow EfficientNet weights, incl B6 and B7 models. Validation scores still pending but looking good.
1 parent 857f330 commit 77e2e0c

File tree

1 file changed

+79
-32
lines changed

1 file changed

+79
-32
lines changed

timm/models/gen_efficientnet.py

Lines changed: 79 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -84,24 +84,34 @@ def _cfg(url='', **kwargs):
8484
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
8585
'efficientnet_b5': _cfg(
8686
url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
87+
'efficientnet_b6': _cfg(
88+
url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
89+
'efficientnet_b7': _cfg(
90+
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
8791
'tf_efficientnet_b0': _cfg(
88-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
92+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
8993
input_size=(3, 224, 224)),
9094
'tf_efficientnet_b1': _cfg(
91-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
95+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
9296
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
9397
'tf_efficientnet_b2': _cfg(
94-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
98+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
9599
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
96100
'tf_efficientnet_b3': _cfg(
97-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
101+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
98102
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
99103
'tf_efficientnet_b4': _cfg(
100-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
104+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
101105
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
102106
'tf_efficientnet_b5': _cfg(
103-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
107+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth',
104108
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
109+
'tf_efficientnet_b6': _cfg(
110+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
111+
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
112+
'tf_efficientnet_b7': _cfg(
113+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
114+
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
105115
'mixnet_s': _cfg(url=''),
106116
'mixnet_m': _cfg(
107117
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
@@ -763,8 +773,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
763773
num_classes=num_classes,
764774
stem_size=32,
765775
channel_multiplier=channel_multiplier,
766-
channel_divisor=8,
767-
channel_min=None,
768776
bn_args=_resolve_bn_args(kwargs),
769777
**kwargs
770778
)
@@ -801,8 +809,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
801809
num_classes=num_classes,
802810
stem_size=32,
803811
channel_multiplier=channel_multiplier,
804-
channel_divisor=8,
805-
channel_min=None,
806812
bn_args=_resolve_bn_args(kwargs),
807813
**kwargs
808814
)
@@ -832,8 +838,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
832838
num_classes=num_classes,
833839
stem_size=8,
834840
channel_multiplier=channel_multiplier,
835-
channel_divisor=8,
836-
channel_min=None,
837841
bn_args=_resolve_bn_args(kwargs),
838842
**kwargs
839843
)
@@ -858,8 +862,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
858862
stem_size=32,
859863
num_features=1024,
860864
channel_multiplier=channel_multiplier,
861-
channel_divisor=8,
862-
channel_min=None,
863865
bn_args=_resolve_bn_args(kwargs),
864866
act_fn=F.relu6,
865867
head_conv='none',
@@ -887,8 +889,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
887889
num_classes=num_classes,
888890
stem_size=32,
889891
channel_multiplier=channel_multiplier,
890-
channel_divisor=8,
891-
channel_min=None,
892892
bn_args=_resolve_bn_args(kwargs),
893893
act_fn=F.relu6,
894894
**kwargs
@@ -926,8 +926,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
926926
num_classes=num_classes,
927927
stem_size=16,
928928
channel_multiplier=channel_multiplier,
929-
channel_divisor=8,
930-
channel_min=None,
931929
bn_args=_resolve_bn_args(kwargs),
932930
act_fn=hard_swish,
933931
se_gate_fn=hard_sigmoid,
@@ -961,8 +959,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
961959
stem_size=32,
962960
num_features=1280, # no idea what this is? try mobile/mnasnet default?
963961
channel_multiplier=channel_multiplier,
964-
channel_divisor=8,
965-
channel_min=None,
966962
bn_args=_resolve_bn_args(kwargs),
967963
**kwargs
968964
)
@@ -992,8 +988,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
992988
stem_size=32,
993989
num_features=1280, # no idea what this is? try mobile/mnasnet default?
994990
channel_multiplier=channel_multiplier,
995-
channel_divisor=8,
996-
channel_min=None,
997991
bn_args=_resolve_bn_args(kwargs),
998992
**kwargs
999993
)
@@ -1024,8 +1018,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
10241018
stem_size=16,
10251019
num_features=1984, # paper suggests this, but is not 100% clear
10261020
channel_multiplier=channel_multiplier,
1027-
channel_divisor=8,
1028-
channel_min=None,
10291021
bn_args=_resolve_bn_args(kwargs),
10301022
**kwargs
10311023
)
@@ -1061,8 +1053,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
10611053
num_classes=num_classes,
10621054
stem_size=32,
10631055
channel_multiplier=channel_multiplier,
1064-
channel_divisor=8,
1065-
channel_min=None,
10661056
bn_args=_resolve_bn_args(kwargs),
10671057
**kwargs
10681058
)
@@ -1107,8 +1097,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
11071097
num_classes=num_classes,
11081098
stem_size=32,
11091099
channel_multiplier=channel_multiplier,
1110-
channel_divisor=8,
1111-
channel_min=None,
11121100
num_features=num_features,
11131101
bn_args=_resolve_bn_args(kwargs),
11141102
act_fn=swish,
@@ -1144,8 +1132,6 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
11441132
stem_size=16,
11451133
num_features=1536,
11461134
channel_multiplier=channel_multiplier,
1147-
channel_divisor=8,
1148-
channel_min=None,
11491135
bn_args=_resolve_bn_args(kwargs),
11501136
act_fn=F.relu,
11511137
**kwargs
@@ -1180,8 +1166,6 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
11801166
stem_size=24,
11811167
num_features=1536,
11821168
channel_multiplier=channel_multiplier,
1183-
channel_divisor=8,
1184-
channel_min=None,
11851169
bn_args=_resolve_bn_args(kwargs),
11861170
act_fn=F.relu,
11871171
**kwargs
@@ -1495,6 +1479,37 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
14951479
return model
14961480

14971481

1482+
1483+
@register_model
1484+
def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1485+
""" EfficientNet-B6 """
1486+
# NOTE for train, drop_rate should be 0.5
1487+
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
1488+
default_cfg = default_cfgs['efficientnet_b6']
1489+
model = _gen_efficientnet(
1490+
channel_multiplier=1.8, depth_multiplier=2.6,
1491+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1492+
model.default_cfg = default_cfg
1493+
if pretrained:
1494+
load_pretrained(model, default_cfg, num_classes, in_chans)
1495+
return model
1496+
1497+
1498+
@register_model
1499+
def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1500+
""" EfficientNet-B7 """
1501+
# NOTE for train, drop_rate should be 0.5
1502+
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
1503+
default_cfg = default_cfgs['efficientnet_b7']
1504+
model = _gen_efficientnet(
1505+
channel_multiplier=2.0, depth_multiplier=3.1,
1506+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1507+
model.default_cfg = default_cfg
1508+
if pretrained:
1509+
load_pretrained(model, default_cfg, num_classes, in_chans)
1510+
return model
1511+
1512+
14981513
@register_model
14991514
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
15001515
""" EfficientNet-B0. Tensorflow compatible variant """
@@ -1585,6 +1600,38 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
15851600
return model
15861601

15871602

1603+
@register_model
1604+
def tf_efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1605+
""" EfficientNet-B6. Tensorflow compatible variant """
1606+
# NOTE for train, drop_rate should be 0.5
1607+
default_cfg = default_cfgs['tf_efficientnet_b6']
1608+
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
1609+
kwargs['pad_type'] = 'same'
1610+
model = _gen_efficientnet(
1611+
channel_multiplier=1.8, depth_multiplier=2.6,
1612+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1613+
model.default_cfg = default_cfg
1614+
if pretrained:
1615+
load_pretrained(model, default_cfg, num_classes, in_chans)
1616+
return model
1617+
1618+
1619+
@register_model
1620+
def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1621+
""" EfficientNet-B7. Tensorflow compatible variant """
1622+
# NOTE for train, drop_rate should be 0.5
1623+
default_cfg = default_cfgs['tf_efficientnet_b7']
1624+
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
1625+
kwargs['pad_type'] = 'same'
1626+
model = _gen_efficientnet(
1627+
channel_multiplier=2.0, depth_multiplier=3.1,
1628+
num_classes=num_classes, in_chans=in_chans, **kwargs)
1629+
model.default_cfg = default_cfg
1630+
if pretrained:
1631+
load_pretrained(model, default_cfg, num_classes, in_chans)
1632+
return model
1633+
1634+
15881635
@register_model
15891636
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
15901637
"""Creates a MixNet Small model.

0 commit comments

Comments
 (0)