
Commit c117328

Ran into tracing issues; rejig fx tracing helpers, move leaf/autowrap registration to timm.layers, use torch.fx.wrap when possible
1 parent b01f23e commit c117328

13 files changed: +265 −151 lines changed
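The commit message above mentions falling back to torch.fx.wrap where available. As a minimal sketch (not part of this commit; the Tiny module and maybe_scale function are made up) of why wrapping a free function keeps FX symbolic tracing from descending into it:

import torch
import torch.fx
from torch import nn


# torch.fx.wrap must be called at module scope; used as a decorator it marks
# the function so symbolic tracing records a single call_function node
# instead of tracing into the body.
@torch.fx.wrap
def maybe_scale(x: torch.Tensor) -> torch.Tensor:
    # tensor-dependent control flow like this would normally break symbolic tracing
    if x.sum() > 0:
        return x * 2
    return x


class Tiny(nn.Module):
    def forward(self, x):
        return maybe_scale(x) + 1


traced = torch.fx.symbolic_trace(Tiny())
print(traced.graph)  # maybe_scale shows up as an opaque leaf call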

timm/layers/__init__.py

Lines changed: 86 additions & 16 deletions
@@ -1,16 +1,40 @@
+from ._fx import (
+    create_feature_extractor,
+    get_graph_node_names,
+    register_notrace_function,
+    register_notrace_module,
+    is_notrace_module,
+    is_notrace_function,
+    get_notrace_modules,
+    get_notrace_functions,
+)
 from .activations import *
-from .adaptive_avgmax_pool import \
-    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .adaptive_avgmax_pool import (
+    adaptive_avgmax_pool2d,
+    select_adaptive_pool2d,
+    AdaptiveAvgMaxPool2d,
+    SelectAdaptivePool2d,
+)
 from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d, create_aa
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
-    set_reentrant_ckpt, use_reentrant_ckpt
+from .config import (
+    is_exportable,
+    is_scriptable,
+    is_no_jit,
+    use_fused_attn,
+    set_exportable,
+    set_scriptable,
+    set_no_jit,
+    set_layer_config,
+    set_fused_attn,
+    set_reentrant_ckpt,
+    use_reentrant_ckpt,
+)
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
@@ -20,8 +44,17 @@
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
-from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
-    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .evo_norm import (
+    EvoNorm2dB0,
+    EvoNorm2dB1,
+    EvoNorm2dB2,
+    EvoNorm2dS0,
+    EvoNorm2dS0a,
+    EvoNorm2dS1,
+    EvoNorm2dS1a,
+    EvoNorm2dS2,
+    EvoNorm2dS2a,
+)
 from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
@@ -37,19 +70,50 @@
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
-    SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
+from .norm_act import (
+    BatchNormAct2d,
+    GroupNormAct,
+    GroupNorm1Act,
+    LayerNormAct,
+    LayerNormAct2d,
+    SyncBatchNormAct,
+    convert_sync_batchnorm,
+    FrozenBatchNormAct2d,
+    freeze_batch_norm_2d,
+    unfreeze_batch_norm_2d,
+)
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
 from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
-from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
-    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
-from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
-    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
-    FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat, RotaryEmbeddingMixed
+from .pos_embed_rel import (
+    RelPosMlp,
+    RelPosBias,
+    RelPosBiasTf,
+    gen_relative_position_index,
+    gen_relative_log_coords,
+    resize_rel_pos_bias_table,
+    resize_rel_pos_bias_table_simple,
+    resize_rel_pos_bias_table_levit,
+)
+from .pos_embed_sincos import (
+    pixel_freq_bands,
+    freq_bands,
+    build_sincos2d_pos_embed,
+    build_fourier_pos_embed,
+    build_rotary_pos_embed,
+    apply_rot_embed,
+    apply_rot_embed_cat,
+    apply_rot_embed_list,
+    apply_keep_indices_nlc,
+    FourierEmbed,
+    RotaryEmbedding,
+    RotaryEmbeddingCat,
+    RotaryEmbeddingMixed,
+    get_mixed_freqs,
+)
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
@@ -60,5 +124,11 @@
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType
-from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
-    init_weight_jax, init_weight_vit
+from .weight_init import (
+    trunc_normal_,
+    trunc_normal_tf_,
+    variance_scaling_,
+    lecun_normal_,
+    init_weight_jax,
+    init_weight_vit,
+)

timm/layers/_fx.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+from typing import Callable, Dict, List, Optional, Union, Tuple, Type
+
+import torch
+from torch import nn
+
+try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
+    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
+
+__all__ = [
+    'register_notrace_module',
+    'is_notrace_module',
+    'get_notrace_modules',
+    'register_notrace_function',
+    'is_notrace_function',
+    'get_notrace_functions',
+    'create_feature_extractor',
+    'get_graph_node_names',
+]
+
+# modules to treat as leafs when tracing
+_leaf_modules = set()
+
+
+def register_notrace_module(module: Type[nn.Module]):
+    """
+    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+    """
+    _leaf_modules.add(module)
+    return module
+
+
+def is_notrace_module(module: Type[nn.Module]):
+    return module in _leaf_modules
+
+
+def get_notrace_modules():
+    return list(_leaf_modules)
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+try:
+    # pass through to torch.fx.wrap when possible, works in some cases our old mechanism doesn't
+    register_notrace_function = torch.fx.wrap  # exists in modern PyTorch
+except AttributeError:
+    # old Torch
+    def register_notrace_function(name_or_fn):
+        if callable(name_or_fn):
+            _autowrap_functions.add(name_or_fn)
+            return name_or_fn
+        _autowrap_functions.add(name_or_fn)
+        return name_or_fn
+
+
+def is_notrace_function(func: Callable):
+    return func in _autowrap_functions
+
+
+def get_notrace_functions():
+    return list(_autowrap_functions)
+
+
+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={
+            'leaf_modules': list(_leaf_modules),
+            'autowrap_functions': list(_autowrap_functions)
+        }
+    )
+
+
+def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
+    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+    return _create_feature_extractor(
+        model, return_nodes,
+        tracer_kwargs={
+            'leaf_modules': list(_leaf_modules),
+            'autowrap_functions': list(_autowrap_functions)
+        }
+    )
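A minimal usage sketch of the helpers added above, assuming the timm.layers exports introduced by this commit; the DynamicPad module, clamp_scores function, and Net wrapper below are illustrative, not part of timm:

import torch
from torch import nn
from timm.layers import (
    register_notrace_module,
    register_notrace_function,
    create_feature_extractor,
    get_graph_node_names,
)


@register_notrace_module  # traced as a leaf module, its forward is never symbolically traced
class DynamicPad(nn.Module):  # hypothetical module with shape-dependent Python logic
    def forward(self, x):
        pad = int(x.shape[-1] % 2)  # plain Python int from a tensor shape
        return nn.functional.pad(x, (0, pad))


@register_notrace_function  # passes through to torch.fx.wrap on recent PyTorch
def clamp_scores(scores: torch.Tensor) -> torch.Tensor:
    return scores.clamp(min=0)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pad = DynamicPad()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return clamp_scores(self.conv(self.pad(x)))


model = Net()
train_nodes, eval_nodes = get_graph_node_names(model)  # leaf/autowrap registrations are honored
extractor = create_feature_extractor(model, return_nodes=['conv'])
feats = extractor(torch.randn(1, 3, 32, 32))  # dict with a 'conv' feature map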

timm/layers/attention.py

Lines changed: 2 additions & 0 deletions
@@ -4,10 +4,12 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_function
 from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@register_notrace_function
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
 
timm/layers/cond_conv2d.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_module
 from .helpers import to_2tuple
 from .conv2d_same import conv2d_same
 from .padding import get_padding_value
@@ -30,6 +31,7 @@ def condconv_initializer(weight):
     return condconv_initializer
 
 
+@register_notrace_module
 class CondConv2d(nn.Module):
     """ Conditionally Parameterized Convolution
     Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

timm/layers/conv2d_same.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 from typing import Tuple, Optional
 
+from ._fx import register_notrace_module
 from .config import is_exportable, is_scriptable
 from .padding import pad_same, pad_same_arg, get_padding_value
 
@@ -27,6 +28,7 @@ def conv2d_same(
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
 
 
+@register_notrace_module
 class Conv2dSame(nn.Conv2d):
     """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
     """

timm/layers/inplace_abn.py

Lines changed: 3 additions & 0 deletions
@@ -15,7 +15,10 @@ def inplace_abn(x, weight, bias, running_mean, running_var,
 def inplace_abn_sync(**kwargs):
     inplace_abn(**kwargs)
 
+from ._fx import register_notrace_module
 
+
+@register_notrace_module
 class InplaceAbn(nn.Module):
     """Activated Batch Normalization
 
timm/layers/non_local_attn.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from torch import nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_module
 from .conv_bn_act import ConvNormAct
 from .helpers import make_divisible
 from .trace_utils import _assert
@@ -69,6 +70,7 @@ def reset_parameters(self):
                 nn.init.constant_(m.bias, 0)
 
 
+@register_notrace_module
 class BilinearAttnTransform(nn.Module):
 
     def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):

timm/layers/norm_act.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 from torchvision.ops.misc import FrozenBatchNorm2d
 
+from ._fx import register_notrace_module
 from .create_act import create_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
 from .norm import RmsNorm, RmsNorm2d
@@ -39,6 +40,7 @@ def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
     return nn.Identity() if act is None else act
 
 
+@register_notrace_module
 class BatchNormAct2d(nn.BatchNorm2d):
     """BatchNorm + Activation
 
@@ -134,6 +136,7 @@ def forward(self, x):
         return x
 
 
+@register_notrace_module
 class SyncBatchNormAct(nn.SyncBatchNorm):
     # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
     # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
@@ -191,6 +194,7 @@ def convert_sync_batchnorm(module, process_group=None):
     return module_output
 
 
+@register_notrace_module
 class FrozenBatchNormAct2d(torch.nn.Module):
     """
     BatchNormAct2d where the batch statistics and the affine parameters are fixed

timm/layers/pool2d_same.py

Lines changed: 18 additions & 4 deletions
@@ -7,17 +7,25 @@
 import torch.nn.functional as F
 from typing import List, Tuple, Optional
 
+from ._fx import register_notrace_module
 from .helpers import to_2tuple
 from .padding import pad_same, get_padding_value
 
 
-def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
-                    ceil_mode: bool = False, count_include_pad: bool = True):
+def avg_pool2d_same(
+        x: torch.Tensor,
+        kernel_size: List[int],
+        stride: List[int],
+        padding: List[int] = (0, 0),
+        ceil_mode: bool = False,
+        count_include_pad: bool = True,
+):
     # FIXME how to deal with count_include_pad vs not for external padding?
     x = pad_same(x, kernel_size, stride)
     return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
 
 
+@register_notrace_module
 class AvgPool2dSame(nn.AvgPool2d):
     """ Tensorflow like 'SAME' wrapper for 2D average pooling
     """
@@ -33,12 +41,18 @@ def forward(self, x):
 
 
 def max_pool2d_same(
-        x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
-        dilation: List[int] = (1, 1), ceil_mode: bool = False):
+        x: torch.Tensor,
+        kernel_size: List[int],
+        stride: List[int],
+        padding: List[int] = (0, 0),
+        dilation: List[int] = (1, 1),
+        ceil_mode: bool = False,
+):
     x = pad_same(x, kernel_size, stride, value=-float('inf'))
     return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
 
 
+@register_notrace_module
 class MaxPool2dSame(nn.MaxPool2d):
     """ Tensorflow like 'SAME' wrapper for 2D max pooling
     """
