Skip to content

Commit a69863a

Browse files
authored
Merge pull request #2156 from huggingface/hiera
WIP Hiera implementation.
2 parents f7aa0a1 + 7a4e987 commit a69863a

40 files changed

+1002
-68
lines changed

tests/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@
5252
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
5353
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
5454
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
55-
'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
55+
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
5656
]
5757

5858
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
5959
NON_STD_FILTERS = [
6060
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
6161
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
6262
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
63-
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
63+
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
6464
]
6565
NUM_NON_STD = len(NON_STD_FILTERS)
6666

@@ -77,7 +77,7 @@
7777
EXCLUDE_FILTERS = ['*enormous*']
7878
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']
7979

80-
EXCLUDE_JIT_FILTERS = []
80+
EXCLUDE_JIT_FILTERS = ['hiera_*']
8181

8282
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
8383
TARGET_BWD_SIZE = 128
@@ -486,7 +486,7 @@ def _create_fx_model(model, train=False):
486486
return fx_model
487487

488488

489-
EXCLUDE_FX_FILTERS = ['vit_gi*']
489+
EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
490490
# not enough memory to run fx on more models than other tests
491491
if 'GITHUB_ACTIONS' in os.environ:
492492
EXCLUDE_FX_FILTERS += [

timm/layers/classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
self.fc = fc
109109
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
110110

111-
def reset(self, num_classes, pool_type=None):
111+
def reset(self, num_classes: int, pool_type: Optional[str] = None):
112112
if pool_type is not None and pool_type != self.global_pool.pool_type:
113113
self.global_pool, self.fc = create_classifier(
114114
self.in_features,
@@ -180,7 +180,7 @@ def __init__(
180180
self.drop = nn.Dropout(drop_rate)
181181
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
182182

183-
def reset(self, num_classes, pool_type=None):
183+
def reset(self, num_classes: int, pool_type: Optional[str] = None):
184184
if pool_type is not None:
185185
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
186186
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()

timm/layers/create_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
4747
if isinstance(norm_layer, str):
4848
if not norm_layer:
4949
return None
50-
layer_name = norm_layer.replace('_', '')
50+
layer_name = norm_layer.replace('_', '').lower()
5151
norm_layer = _NORM_MAP[layer_name]
5252
else:
5353
norm_layer = norm_layer

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .ghostnet import *
2727
from .hardcorenas import *
2828
from .hgnet import *
29+
from .hiera import *
2930
from .hrnet import *
3031
from .inception_next import *
3132
from .inception_resnet_v2 import *

timm/models/beit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def group_matcher(self, coarse=False):
395395
def get_classifier(self):
396396
return self.head
397397

398-
def reset_classifier(self, num_classes, global_pool=None):
398+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
399399
self.num_classes = num_classes
400400
if global_pool is not None:
401401
self.global_pool = global_pool

timm/models/cait.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def _matcher(name):
331331
def get_classifier(self):
332332
return self.head
333333

334-
def reset_classifier(self, num_classes, global_pool=None):
334+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
335335
self.num_classes = num_classes
336336
if global_pool is not None:
337337
assert global_pool in ('', 'token', 'avg')

timm/models/coat.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
88
Modified from timm/models/vision_transformer.py
99
"""
10-
from functools import partial
11-
from typing import Tuple, List, Union
10+
from typing import List, Optional, Union, Tuple
1211

1312
import torch
1413
import torch.nn as nn
@@ -560,7 +559,7 @@ def group_matcher(self, coarse=False):
560559
def get_classifier(self):
561560
return self.head
562561

563-
def reset_classifier(self, num_classes, global_pool=None):
562+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
564563
self.num_classes = num_classes
565564
if global_pool is not None:
566565
assert global_pool in ('token', 'avg')

timm/models/convit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
'''These modules are adapted from those of timm, see
2222
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
2323
'''
24-
25-
from functools import partial
24+
from typing import Optional
2625

2726
import torch
2827
import torch.nn as nn
@@ -349,7 +348,7 @@ def set_grad_checkpointing(self, enable=True):
349348
def get_classifier(self):
350349
return self.head
351350

352-
def reset_classifier(self, num_classes, global_pool=None):
351+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
353352
self.num_classes = num_classes
354353
if global_pool is not None:
355354
assert global_pool in ('', 'token', 'avg')

timm/models/convmixer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
""" ConvMixer
22
33
"""
4+
from typing import Optional
5+
46
import torch
57
import torch.nn as nn
68

@@ -75,7 +77,7 @@ def set_grad_checkpointing(self, enable=True):
7577
def get_classifier(self):
7678
return self.head
7779

78-
def reset_classifier(self, num_classes, global_pool=None):
80+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
7981
self.num_classes = num_classes
8082
if global_pool is not None:
8183
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)

timm/models/convnext.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
3838
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
3939

40-
from collections import OrderedDict
4140
from functools import partial
4241
from typing import Callable, List, Optional, Tuple, Union
4342

0 commit comments

Comments
 (0)