Commit ba15ca4

Add ported EfficientNet-L2, B0-B7 NoisyStudent weights from TF TPU
1 parent d0eb59e commit ba15ca4

2 files changed, 201 additions and 22 deletions

README.md

Lines changed: 44 additions & 20 deletions
@@ -2,6 +2,9 @@
 
 ## What's New
 
+### Feb 12, 2020
+* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
+
 ### Feb 6, 2020
 * Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
 
@@ -98,15 +101,16 @@ Included models:
 * DPN (from [myself](https://github.com/rwightman/pytorch-dpn-pretrained))
 * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
 * EfficientNet (from my standalone [GenEfficientNet](https://github.com/rwightman/gen-efficientnet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
-* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) -- TF weights ported
-* EfficientNet (B0-B7) (https://arxiv.org/abs/1905.11946) -- TF weights ported, B0-B2 finetuned PyTorch
-* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) --TF weights ported
-* MixNet (https://arxiv.org/abs/1907.09595) -- TF weights ported, PyTorch finetuned (S, M, L) or trained models (XL)
-* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) -- trained in PyTorch
+* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
+* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
+* EfficientNet (B0-B7) (https://arxiv.org/abs/1905.11946)
+* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
+* MixNet (https://arxiv.org/abs/1907.09595)
+* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
 * MobileNet-V2 (https://arxiv.org/abs/1801.04381)
-* FBNet-C (https://arxiv.org/abs/1812.03443) -- trained in PyTorch
-* Single-Path NAS (https://arxiv.org/abs/1904.02877) -- pixel1 variant
-* MobileNet-V3 (https://arxiv.org/abs/1905.02244) -- pretrained PyTorch model, official TF weights ported
+* FBNet-C (https://arxiv.org/abs/1812.03443)
+* Single-Path NAS (https://arxiv.org/abs/1904.02877)
+* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
 * HRNet
 * code from https://github.com/HRNet/HRNet-Image-Classification, paper https://arxiv.org/abs/1908.07919
 * SelecSLS
@@ -178,30 +182,48 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
 
 | Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
 |---|---|---|---|---|---|
+| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 |
+| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 |
+| tf_efficientnet_l2_ns *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 |
+| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 475 |
+| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 |
+| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 |
+| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 |
+| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 |
+| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 |
+| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 |
 | tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 |
-| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 |
-| tf_efficientnet_b8 | 85.37 (14.63) | 97.39 (2.61) | 87.4 | bicubic | 672 |
+| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 |
+| tf_efficientnet_b8 | 85.37 (14.63) | 97.39 (2.61) | 87.4 | bicubic | 672 |
 | tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 |
+| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 |
+| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 |
 | tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 |
 | tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 |
-| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 |
-| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 |
+| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 |
+| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 |
 | tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 |
 | tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 |
 | tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 |
 | tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 |
-| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 |
-| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 |
-| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 |
-| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 |
+| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 |
+| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 |
+| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 |
+| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 |
+| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 |
+| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 |
 | tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 |
 | tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 |
-| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 |
-| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 |
+| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 |
+| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 |
+| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 |
+| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 |
 | tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 |
 | tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 |
-| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 |
-| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 |
+| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 |
+| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 |
+| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 |
+| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 |
 | gluon_senet154 | 81.224 (18.776) | 95.356 (4.644) | 115.09 | bicubic | 224 |
 | gluon_resnet152_v1s | 81.012 (18.988) | 95.416 (4.584) | 60.32 | bicubic | 224 |
 | gluon_seresnext101_32x4d | 80.902 (19.098) | 95.294 (4.706) | 48.96 | bicubic | 224 |
@@ -233,10 +255,12 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
 | tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 |
 | tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 |
 | tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 |
+| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 |
 | gluon_inception_v3 | 78.804 (21.196) | 94.380 (5.620) | 27.16M | bicubic | 299 |
 | tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 |
 | tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 |
 | gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | 224 |
+| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 |
 | tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 |
 | gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | 224 |
 | tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 |
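Note on reading the results table: each accuracy cell pairs a precision value with its error in parentheses, and in the rows shown the two sum to 100. The sketch below parses a row and checks that relationship. The helper name `parse_row` is hypothetical and not part of the repository, and rows whose metrics are still `TBD` are not handled.

```python
def parse_row(row: str) -> dict:
    """Split one markdown row of the results table into named fields."""
    cells = [c.strip() for c in row.strip().strip("|").split("|")]
    name, top1, top5, params, scaling, size = cells

    def split_metric(cell: str) -> tuple:
        # "88.352 (11.648)" -> (88.352, 11.648)
        value, err = cell.replace("(", "").replace(")", "").split()
        return float(value), float(err)

    return {
        "model": name,
        "top1": split_metric(top1),
        "top5": split_metric(top5),
        "params": params,          # kept as text; one row uses "27.16M"
        "interpolation": scaling,
        "image_size": int(size),
    }


# Two rows copied from the table in this commit.
rows = [
    "| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 |",
    "| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 |",
]
parsed = [parse_row(r) for r in rows]

for p in parsed:
    prec, err = p["top1"]
    # Precision and error are complementary percentages.
    print(p["model"], p["image_size"], round(prec + err, 3))
```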

timm/models/efficientnet.py

Lines changed: 157 additions & 2 deletions
@@ -2,10 +2,11 @@
 
 An implementation of EfficienNet that covers variety of related models with efficient architectures:
 
-* EfficientNet (B0-B8 + Tensorflow pretrained AutoAug/RandAug/AdvProp weight ports)
+* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
   - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
   - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
   - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
+  - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
 
 * MixNet (Small, Medium, and Large)
   - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
@@ -91,6 +92,8 @@ def _cfg(url='', **kwargs):
         url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
     'efficientnet_b8': _cfg(
         url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+    'efficientnet_l2': _cfg(
+        url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
     'efficientnet_es': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
     'efficientnet_em': _cfg(
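Note on `pool_size`: the values in these configs are consistent with the spatial size of the final feature map. The sketch below assumes an overall network stride of 32 with rounding up, which is an assumption about the architecture rather than something stated in this diff, and checks it against a few entries from this commit.

```python
import math

TOTAL_STRIDE = 32  # assumption: the network downsamples its input by 32 overall


def expected_pool_size(input_size):
    """Estimate the final feature-map size for a (C, H, W) input size."""
    _, height, width = input_size
    return (math.ceil(height / TOTAL_STRIDE), math.ceil(width / TOTAL_STRIDE))


# (input_size, pool_size) pairs copied from the configs in this commit.
configs = {
    "efficientnet_b8": ((3, 672, 672), (21, 21)),
    "efficientnet_l2": ((3, 800, 800), (25, 25)),
    "tf_efficientnet_b1_ns": ((3, 240, 240), (8, 8)),
    "tf_efficientnet_l2_ns_475": ((3, 475, 475), (15, 15)),
}

mismatches = {
    name: (expected_pool_size(input_size), pool_size)
    for name, (input_size, pool_size) in configs.items()
    if expected_pool_size(input_size) != pool_size
}
print(mismatches)
```

An empty `mismatches` dict means the assumed stride explains every listed `pool_size`.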
@@ -162,6 +165,36 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
         input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+    'tf_efficientnet_b0_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
+        input_size=(3, 224, 224)),
+    'tf_efficientnet_b1_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
+        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+    'tf_efficientnet_b2_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
+        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
+    'tf_efficientnet_b3_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
+        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+    'tf_efficientnet_b4_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
+        input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
+    'tf_efficientnet_b5_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
+        input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+    'tf_efficientnet_b6_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
+        input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
+    'tf_efficientnet_b7_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
+        input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+    'tf_efficientnet_l2_ns_475': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
+        input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
+    'tf_efficientnet_l2_ns': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
+        input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
     'tf_efficientnet_es': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
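Note on `crop_pct`: every value added in this hunk is within 0.001 of `size / (size + 32)`, that is, a center crop taken from an image 32 pixels larger than the network input. The sketch below checks that observation and shows the implied resize target. It assumes the evaluation transform resizes to roughly `floor(size / crop_pct)` before center-cropping; that transform is not part of this diff, so treat `resize_target` as illustrative.

```python
import math

# (input size, crop_pct) pairs copied from the NoisyStudent configs in this commit.
NS_CONFIGS = {
    "tf_efficientnet_b1_ns": (240, 0.882),
    "tf_efficientnet_b2_ns": (260, 0.890),
    "tf_efficientnet_b3_ns": (300, 0.904),
    "tf_efficientnet_b4_ns": (380, 0.922),
    "tf_efficientnet_b5_ns": (456, 0.934),
    "tf_efficientnet_b6_ns": (528, 0.942),
    "tf_efficientnet_b7_ns": (600, 0.949),
    "tf_efficientnet_l2_ns_475": (475, 0.936),
    "tf_efficientnet_l2_ns": (800, 0.961),
}

CROP_PADDING = 32  # offset inferred from the numbers above, not read from the code


def approx_crop_pct(size):
    """Crop fraction when cropping `size` pixels out of `size + CROP_PADDING`."""
    return size / (size + CROP_PADDING)


def resize_target(size, crop_pct):
    """Assumed pre-crop resize size for a given input size and crop fraction."""
    return int(math.floor(size / crop_pct))


# Largest disagreement between the configured and the approximated crop fraction.
max_gap = max(abs(approx_crop_pct(s) - c) for s, c in NS_CONFIGS.values())

for name, (size, crop_pct) in NS_CONFIGS.items():
    print(name, size, crop_pct, round(approx_crop_pct(size), 4), resize_target(size, crop_pct))
```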
@@ -208,7 +241,7 @@ class EfficientNet(nn.Module):
     """ (Generic) EfficientNet
 
     A flexible and performant PyTorch implementation of efficient network architectures, including:
-      * EfficientNet B0-B8
+      * EfficientNet B0-B8, L2
       * EfficientNet-EdgeTPU
       * EfficientNet-CondConv
       * MixNet S, M, L, XL
@@ -586,6 +619,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
       'efficientnet-b6': (1.8, 2.6, 528, 0.5),
       'efficientnet-b7': (2.0, 3.1, 600, 0.5),
       'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+      'efficientnet-l2': (4.3, 5.3, 800, 0.5),
 
     Args:
       channel_multiplier: multiplier to number of channels per layer
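Note on the scaling tuple: the first two values of the new `efficientnet-l2` entry (4.3 and 5.3) match the `channel_multiplier` and `depth_multiplier` passed by the L2 entrypoints, and the third matches the 800-pixel input size. The sketch below shows how such multipliers are typically applied, using the rounding rules of the TF reference implementation. The helpers in this repository may round differently, and the B0 base values (a 32-channel stem, per-stage repeats of 1, 2, 2, 3, 3, 4, 1) come from the EfficientNet baseline, not from this diff.

```python
import math


def round_filters(filters, width_coefficient, divisor=8):
    """Scale a channel count and snap it to a multiple of `divisor`."""
    scaled = filters * width_coefficient
    new_filters = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
    # Avoid rounding down by more than 10 percent.
    if new_filters < 0.9 * scaled:
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, depth_coefficient):
    """Scale the number of block repeats in a stage, rounding up."""
    return int(math.ceil(depth_coefficient * repeats))


B0_STEM_CHANNELS = 32                    # assumed baseline value
B0_STAGE_REPEATS = [1, 2, 2, 3, 3, 4, 1]  # assumed baseline values

l2_stem = round_filters(B0_STEM_CHANNELS, 4.3)
l2_repeats = [round_repeats(r, 5.3) for r in B0_STAGE_REPEATS]
print(l2_stem, l2_repeats)
```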
@@ -928,6 +962,24 @@ def efficientnet_b7(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def efficientnet_b8(pretrained=False, **kwargs):
+    """ EfficientNet-B8 """
+    # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnet_l2(pretrained=False, **kwargs):
+    """ EfficientNet-L2."""
+    # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+    return model
+
+
 @register_model
 def efficientnet_es(pretrained=False, **kwargs):
     """ EfficientNet-Edge Small. """
@@ -1166,6 +1218,109 @@ def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant """
+    # NOTE for train, drop_rate should be 0.5
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
+    """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """
+    # NOTE for train, drop_rate should be 0.5
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
+    """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant """
+    # NOTE for train, drop_rate should be 0.5
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
+    """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """
+    # NOTE for train, drop_rate should be 0.5
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet(
+        'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+    return model
+
 
 @register_model
 def tf_efficientnet_es(pretrained=False, **kwargs):
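Note on `pad_type='same'`: the `tf_` entrypoints request TensorFlow-style `SAME` padding, where the amount of padding depends on the input size and any odd pixel goes at the end. The sketch below works through that arithmetic for one spatial dimension. It is written independently of this repository's padding code, and `same_padding_1d` is a hypothetical helper name.

```python
import math


def same_padding_1d(size, kernel, stride=1, dilation=1):
    """Return (output size, pad before, pad after) for TF-style SAME padding."""
    out = math.ceil(size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    # Total padding needed so that `out` windows fit over the padded input.
    total = max((out - 1) * stride + effective_kernel - size, 0)
    before = total // 2
    after = total - before  # the extra pixel, if any, is added at the end
    return out, before, after


# A 3x3 stride-2 convolution at three of the input sizes used in this commit.
for size in (224, 475, 800):
    print(size, same_padding_1d(size, kernel=3, stride=2))
```

For an even input such as 224 the padding is asymmetric (0 before, 1 after), which a fixed symmetric padding of 1 would not reproduce. That mismatch is one reason ported TF weights are usually paired with a padding mode of this kind.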
