Skip to content

Commit 89dffc5

Browse files
committed
Another small fix for original mambaout models, no classifier nn.Linear when num_classe=0 on init
1 parent fad4538 commit 89dffc5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/mambaout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __init__(
151151
self.num_features = in_features
152152
self.pre_logits = nn.Identity()
153153

154-
self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
154+
self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
155155
self.head_dropout = nn.Dropout(drop_rate)
156156

157157
def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):

0 commit comments

Comments
 (0)