
Commit b9a336f

Use both torch.fx.wrap and autowrap fns list to avoid some issues
1 parent 74a052d commit b9a336f

File tree: 5 files changed, +8 −17 lines changed


timm/layers/_fx.py

Lines changed: 4 additions & 11 deletions
@@ -46,17 +46,10 @@ def get_notrace_modules():
 # Functions we want to autowrap (treat them as leaves)
 _autowrap_functions = set()
 
-try:
-    # pass through to torch.fx.wrap when possible, works in some cases our old mechanism doesn't
-    register_notrace_function = torch.fx.wrap  # exists in modern PyTorch
-except AttributeError:
-    # old Torch
-    def register_notrace_function(name_or_fn):
-        if callable(name_or_fn):
-            _autowrap_functions.add(name_or_fn)
-            return name_or_fn
-        _autowrap_functions.add(name_or_fn)
-        return name_or_fn
+
+def register_notrace_function(name_or_fn):
+    _autowrap_functions.add(name_or_fn)
+    return name_or_fn
 
 
 def is_notrace_function(func: Callable):
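
With this change, register_notrace_function always records the function in the _autowrap_functions set, while @torch.fx.wrap is applied explicitly at each definition site. Below is a minimal sketch of how such a registered set can then be handed to an FX tracer; the build_extractor helper and its arguments are illustrative placeholders (not the exact timm code), and it assumes torchvision's create_feature_extractor with its tracer_kwargs parameter.

# Minimal sketch, not the exact timm implementation.
from typing import Callable, Set

import torch
from torchvision.models.feature_extraction import create_feature_extractor

_autowrap_functions: Set[Callable] = set()


def register_notrace_function(name_or_fn):
    # Record the function so FX tracers treat it as a leaf instead of tracing into it.
    _autowrap_functions.add(name_or_fn)
    return name_or_fn


def build_extractor(model: torch.nn.Module, return_nodes):
    # Forward the registered functions to the underlying fx.Tracer as autowrap_functions.
    return create_feature_extractor(
        model,
        return_nodes=return_nodes,
        tracer_kwargs={'autowrap_functions': tuple(_autowrap_functions)},
    )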

timm/layers/attention.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@torch.fx.wrap
 @register_notrace_function
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
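
The stacked @torch.fx.wrap decorator is what keeps this call a leaf under plain torch.fx.symbolic_trace, where no custom tracer (and hence no autowrap_functions list) is involved. The following is a self-contained toy illustration of that behavior; ToyAttn and the local maybe_add_mask are stand-ins, not timm code. Without the wrap, tracing would evaluate the attn_mask is None check on a Proxy at trace time and bake in one branch.

from typing import Optional

import torch
import torch.fx


@torch.fx.wrap
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    # Data-dependent Python branch; kept as an opaque call_function node when wrapped.
    return scores if attn_mask is None else scores + attn_mask


class ToyAttn(torch.nn.Module):
    def forward(self, scores, attn_mask=None):
        return maybe_add_mask(scores, attn_mask)


traced = torch.fx.symbolic_trace(ToyAttn())
print(traced.graph)  # shows a single call_function node for maybe_add_mask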

timm/layers/attention_pool2d.py

Lines changed: 0 additions & 4 deletions
@@ -12,16 +12,12 @@
 import torch
 import torch.nn as nn
 
-from ._fx import register_notrace_function
 from .config import use_fused_attn
 from .helpers import to_2tuple
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
 
-# have to register again for some reason
-register_notrace_function(resample_abs_pos_embed)
-
 
 class RotAttentionPool2d(nn.Module):
     """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.

timm/layers/pos_embed.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 _logger = logging.getLogger(__name__)
 
 
+@torch.fx.wrap
 @register_notrace_function
 def resample_abs_pos_embed(
         posemb: torch.Tensor,
@@ -58,6 +59,7 @@ def resample_abs_pos_embed(
     return posemb
 
 
+@torch.fx.wrap
 @register_notrace_function
 def resample_abs_pos_embed_nhwc(
         posemb: torch.Tensor,

timm/layers/pos_embed_sincos.py

Lines changed: 1 addition & 2 deletions
@@ -512,6 +512,7 @@ def init_random_2d_freqs(
     return torch.stack([fx, fy], dim=0)
 
 
+@torch.fx.wrap
 @register_notrace_function
 def get_mixed_freqs(
         freqs: torch.Tensor,
@@ -584,8 +585,6 @@ def __init__(
         )  # (2, depth, num_heads, head_dim//2)
         self.freqs = nn.Parameter(freqs)
 
-
-
     def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
         """Generate rotary embeddings for the given spatial shape.
 

0 commit comments
