@@ -84,24 +84,34 @@ def _cfg(url='', **kwargs):
84
84
url = '' , input_size = (3 , 380 , 380 ), pool_size = (12 , 12 ), crop_pct = 0.922 ),
85
85
'efficientnet_b5' : _cfg (
86
86
url = '' , input_size = (3 , 456 , 456 ), pool_size = (15 , 15 ), crop_pct = 0.934 ),
87
+ 'efficientnet_b6' : _cfg (
88
+ url = '' , input_size = (3 , 528 , 528 ), pool_size = (17 , 17 ), crop_pct = 0.942 ),
89
+ 'efficientnet_b7' : _cfg (
90
+ url = '' , input_size = (3 , 600 , 600 ), pool_size = (19 , 19 ), crop_pct = 0.949 ),
87
91
'tf_efficientnet_b0' : _cfg (
88
- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548 .pth' ,
92
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33 .pth' ,
89
93
input_size = (3 , 224 , 224 )),
90
94
'tf_efficientnet_b1' : _cfg (
91
- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4 .pth' ,
95
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0 .pth' ,
92
96
input_size = (3 , 240 , 240 ), pool_size = (8 , 8 ), crop_pct = 0.882 ),
93
97
'tf_efficientnet_b2' : _cfg (
94
- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04 .pth' ,
98
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97 .pth' ,
95
99
input_size = (3 , 260 , 260 ), pool_size = (9 , 9 ), crop_pct = 0.890 ),
96
100
'tf_efficientnet_b3' : _cfg (
97
- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955 .pth' ,
101
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e .pth' ,
98
102
input_size = (3 , 300 , 300 ), pool_size = (10 , 10 ), crop_pct = 0.904 ),
99
103
'tf_efficientnet_b4' : _cfg (
100
- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed .pth' ,
104
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c .pth' ,
101
105
input_size = (3 , 380 , 380 ), pool_size = (12 , 12 ), crop_pct = 0.922 ),
102
106
'tf_efficientnet_b5' : _cfg (
103
- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9 .pth' ,
107
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74 .pth' ,
104
108
input_size = (3 , 456 , 456 ), pool_size = (15 , 15 ), crop_pct = 0.934 ),
109
+ 'tf_efficientnet_b6' : _cfg (
110
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth' ,
111
+ input_size = (3 , 528 , 528 ), pool_size = (17 , 17 ), crop_pct = 0.942 ),
112
+ 'tf_efficientnet_b7' : _cfg (
113
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth' ,
114
+ input_size = (3 , 600 , 600 ), pool_size = (19 , 19 ), crop_pct = 0.949 ),
105
115
'mixnet_s' : _cfg (url = '' ),
106
116
'mixnet_m' : _cfg (
107
117
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth' ),
@@ -763,8 +773,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
763
773
num_classes = num_classes ,
764
774
stem_size = 32 ,
765
775
channel_multiplier = channel_multiplier ,
766
- channel_divisor = 8 ,
767
- channel_min = None ,
768
776
bn_args = _resolve_bn_args (kwargs ),
769
777
** kwargs
770
778
)
@@ -801,8 +809,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
801
809
num_classes = num_classes ,
802
810
stem_size = 32 ,
803
811
channel_multiplier = channel_multiplier ,
804
- channel_divisor = 8 ,
805
- channel_min = None ,
806
812
bn_args = _resolve_bn_args (kwargs ),
807
813
** kwargs
808
814
)
@@ -832,8 +838,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
832
838
num_classes = num_classes ,
833
839
stem_size = 8 ,
834
840
channel_multiplier = channel_multiplier ,
835
- channel_divisor = 8 ,
836
- channel_min = None ,
837
841
bn_args = _resolve_bn_args (kwargs ),
838
842
** kwargs
839
843
)
@@ -858,8 +862,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
858
862
stem_size = 32 ,
859
863
num_features = 1024 ,
860
864
channel_multiplier = channel_multiplier ,
861
- channel_divisor = 8 ,
862
- channel_min = None ,
863
865
bn_args = _resolve_bn_args (kwargs ),
864
866
act_fn = F .relu6 ,
865
867
head_conv = 'none' ,
@@ -887,8 +889,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
887
889
num_classes = num_classes ,
888
890
stem_size = 32 ,
889
891
channel_multiplier = channel_multiplier ,
890
- channel_divisor = 8 ,
891
- channel_min = None ,
892
892
bn_args = _resolve_bn_args (kwargs ),
893
893
act_fn = F .relu6 ,
894
894
** kwargs
@@ -926,8 +926,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
926
926
num_classes = num_classes ,
927
927
stem_size = 16 ,
928
928
channel_multiplier = channel_multiplier ,
929
- channel_divisor = 8 ,
930
- channel_min = None ,
931
929
bn_args = _resolve_bn_args (kwargs ),
932
930
act_fn = hard_swish ,
933
931
se_gate_fn = hard_sigmoid ,
@@ -961,8 +959,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
961
959
stem_size = 32 ,
962
960
num_features = 1280 , # no idea what this is? try mobile/mnasnet default?
963
961
channel_multiplier = channel_multiplier ,
964
- channel_divisor = 8 ,
965
- channel_min = None ,
966
962
bn_args = _resolve_bn_args (kwargs ),
967
963
** kwargs
968
964
)
@@ -992,8 +988,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
992
988
stem_size = 32 ,
993
989
num_features = 1280 , # no idea what this is? try mobile/mnasnet default?
994
990
channel_multiplier = channel_multiplier ,
995
- channel_divisor = 8 ,
996
- channel_min = None ,
997
991
bn_args = _resolve_bn_args (kwargs ),
998
992
** kwargs
999
993
)
@@ -1024,8 +1018,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
1024
1018
stem_size = 16 ,
1025
1019
num_features = 1984 , # paper suggests this, but is not 100% clear
1026
1020
channel_multiplier = channel_multiplier ,
1027
- channel_divisor = 8 ,
1028
- channel_min = None ,
1029
1021
bn_args = _resolve_bn_args (kwargs ),
1030
1022
** kwargs
1031
1023
)
@@ -1061,8 +1053,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
1061
1053
num_classes = num_classes ,
1062
1054
stem_size = 32 ,
1063
1055
channel_multiplier = channel_multiplier ,
1064
- channel_divisor = 8 ,
1065
- channel_min = None ,
1066
1056
bn_args = _resolve_bn_args (kwargs ),
1067
1057
** kwargs
1068
1058
)
@@ -1107,8 +1097,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
1107
1097
num_classes = num_classes ,
1108
1098
stem_size = 32 ,
1109
1099
channel_multiplier = channel_multiplier ,
1110
- channel_divisor = 8 ,
1111
- channel_min = None ,
1112
1100
num_features = num_features ,
1113
1101
bn_args = _resolve_bn_args (kwargs ),
1114
1102
act_fn = swish ,
@@ -1144,8 +1132,6 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
1144
1132
stem_size = 16 ,
1145
1133
num_features = 1536 ,
1146
1134
channel_multiplier = channel_multiplier ,
1147
- channel_divisor = 8 ,
1148
- channel_min = None ,
1149
1135
bn_args = _resolve_bn_args (kwargs ),
1150
1136
act_fn = F .relu ,
1151
1137
** kwargs
@@ -1180,8 +1166,6 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
1180
1166
stem_size = 24 ,
1181
1167
num_features = 1536 ,
1182
1168
channel_multiplier = channel_multiplier ,
1183
- channel_divisor = 8 ,
1184
- channel_min = None ,
1185
1169
bn_args = _resolve_bn_args (kwargs ),
1186
1170
act_fn = F .relu ,
1187
1171
** kwargs
@@ -1495,6 +1479,37 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1495
1479
return model
1496
1480
1497
1481
1482
+
1483
+ @register_model
1484
+ def efficientnet_b6 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1485
+ """ EfficientNet-B6 """
1486
+ # NOTE for train, drop_rate should be 0.5
1487
+ #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
1488
+ default_cfg = default_cfgs ['efficientnet_b6' ]
1489
+ model = _gen_efficientnet (
1490
+ channel_multiplier = 1.8 , depth_multiplier = 2.6 ,
1491
+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1492
+ model .default_cfg = default_cfg
1493
+ if pretrained :
1494
+ load_pretrained (model , default_cfg , num_classes , in_chans )
1495
+ return model
1496
+
1497
+
1498
+ @register_model
1499
+ def efficientnet_b7 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1500
+ """ EfficientNet-B7 """
1501
+ # NOTE for train, drop_rate should be 0.5
1502
+ #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
1503
+ default_cfg = default_cfgs ['efficientnet_b7' ]
1504
+ model = _gen_efficientnet (
1505
+ channel_multiplier = 2.0 , depth_multiplier = 3.1 ,
1506
+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1507
+ model .default_cfg = default_cfg
1508
+ if pretrained :
1509
+ load_pretrained (model , default_cfg , num_classes , in_chans )
1510
+ return model
1511
+
1512
+
1498
1513
@register_model
1499
1514
def tf_efficientnet_b0 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1500
1515
""" EfficientNet-B0. Tensorflow compatible variant """
@@ -1585,6 +1600,38 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
1585
1600
return model
1586
1601
1587
1602
1603
+ @register_model
1604
+ def tf_efficientnet_b6 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1605
+ """ EfficientNet-B6. Tensorflow compatible variant """
1606
+ # NOTE for train, drop_rate should be 0.5
1607
+ default_cfg = default_cfgs ['tf_efficientnet_b6' ]
1608
+ kwargs ['bn_eps' ] = _BN_EPS_TF_DEFAULT
1609
+ kwargs ['pad_type' ] = 'same'
1610
+ model = _gen_efficientnet (
1611
+ channel_multiplier = 1.8 , depth_multiplier = 2.6 ,
1612
+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1613
+ model .default_cfg = default_cfg
1614
+ if pretrained :
1615
+ load_pretrained (model , default_cfg , num_classes , in_chans )
1616
+ return model
1617
+
1618
+
1619
+ @register_model
1620
+ def tf_efficientnet_b7 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1621
+ """ EfficientNet-B7. Tensorflow compatible variant """
1622
+ # NOTE for train, drop_rate should be 0.5
1623
+ default_cfg = default_cfgs ['tf_efficientnet_b7' ]
1624
+ kwargs ['bn_eps' ] = _BN_EPS_TF_DEFAULT
1625
+ kwargs ['pad_type' ] = 'same'
1626
+ model = _gen_efficientnet (
1627
+ channel_multiplier = 2.0 , depth_multiplier = 3.1 ,
1628
+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1629
+ model .default_cfg = default_cfg
1630
+ if pretrained :
1631
+ load_pretrained (model , default_cfg , num_classes , in_chans )
1632
+ return model
1633
+
1634
+
1588
1635
@register_model
1589
1636
def mixnet_s (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1590
1637
"""Creates a MixNet Small model.
0 commit comments