Skip to content

Commit 751b0bb

Browse files
committed
Add global_pool (--gp) arg changes to allow passing 'fast' easily for train/validate to avoid channels_last issue with AdaptiveAvgPool
1 parent 9c297ec commit 751b0bb

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

timm/models/factory.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,17 @@ def create_model(
3939
kwargs.pop('bn_momentum', None)
4040
kwargs.pop('bn_eps', None)
4141

42-
# Parameters that aren't supported by all models should default to None in command line args,
43-
# remove them if they are present and not set so that non-supporting models don't break.
44-
if kwargs.get('drop_block_rate', None) is None:
45-
kwargs.pop('drop_block_rate', None)
46-
4742
# handle backwards compat with drop_connect -> drop_path change
4843
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
4944
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
5045
print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
5146
" Setting drop_path to %f." % drop_connect_rate)
5247
kwargs['drop_path_rate'] = drop_connect_rate
5348

54-
if kwargs.get('drop_path_rate', None) is None:
55-
kwargs.pop('drop_path_rate', None)
49+
# Parameters that aren't supported by all models or are intended to only override model defaults if set
50+
# should default to None in command line args/cfg. Remove them if they are present and not set so that
51+
# non-supporting models don't break and default args remain in effect.
52+
kwargs = {k: v for k, v in kwargs.items() if v is not None}
5653

5754
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
5855
if is_model(model_name):

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@
7474
help='prevent resume of optimizer state when resuming model')
7575
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
7676
help='number of label classes (default: 1000)')
77-
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
78-
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
77+
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
78+
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
7979
parser.add_argument('--img-size', type=int, default=None, metavar='N',
8080
help='Image patch size (default: None => model default)')
8181
parser.add_argument('--crop-pct', default=None, type=float,

validate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
help='Number classes in dataset')
6565
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
6666
help='path to class to idx mapping file (default: "")')
67+
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
68+
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
6769
parser.add_argument('--log-freq', default=10, type=int,
6870
metavar='N', help='batch logging frequency (default: 10)')
6971
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -127,6 +129,7 @@ def validate(args):
127129
pretrained=args.pretrained,
128130
num_classes=args.num_classes,
129131
in_chans=3,
132+
global_pool=args.gp,
130133
scriptable=args.torchscript)
131134

132135
if args.checkpoint:

0 commit comments

Comments
 (0)