Skip to content

Commit 6a621b5

Browse files
committed
Pass block_fn and mlp_layer through from NaFlexVit cfg, fixes a few models
1 parent f9b3d7e commit 6a621b5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

timm/models/naflexvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -844,8 +844,8 @@ def __init__(
844844
norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm
845845
embed_norm_layer = get_norm_layer(cfg.embed_norm_layer)
846846
act_layer = get_act_layer(cfg.act_layer) or nn.GELU
847-
block_fn = Block # TODO: Support configurable block_fn via string lookup
848-
mlp_layer = Mlp # TODO: Support configurable mlp_layer via string lookup
847+
block_fn = cfg.block_fn or Block # TODO: Support configurable block_fn via string lookup
848+
mlp_layer = cfg.mlp_layer or Mlp # TODO: Support configurable mlp_layer via string lookup
849849

850850
# Store instance variables
851851
self.num_classes = num_classes

0 commit comments

Comments
 (0)