Skip to content

Add kwargs for timm.create_model in TimmWrapper #38860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""Configuration for TimmWrapper models"""

from typing import Any, Dict
from typing import Any, Dict, Optional

from ...configuration_utils import PretrainedConfig
from ...utils import is_timm_available, logging, requires_backends
Expand Down Expand Up @@ -45,6 +45,9 @@ class TimmWrapperConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
do_pooling (`bool`, *optional*, defaults to `True`):
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
model_init_kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments to pass to the `timm.create_model` function. For example, `model_init_kwargs={"depth": 3}`
with `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` creates a model with 3 blocks. Defaults to `None`.

Example:
```python
Expand All @@ -60,9 +63,16 @@ class TimmWrapperConfig(PretrainedConfig):

model_type = "timm_wrapper"

def __init__(
    self,
    initializer_range: float = 0.02,
    do_pooling: bool = True,
    model_init_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """Store the wrapper-specific options and hand everything else to the parent config.

    Args:
        initializer_range: Stored as `self.initializer_range`.
        do_pooling: Stored as `self.do_pooling`.
        model_init_kwargs: Extra keyword arguments intended for `timm.create_model`.
            Stored exactly as given, so it stays `None` when not provided.
        **kwargs: Forwarded unchanged to the parent class `__init__`.
    """
    self.initializer_range = initializer_range
    self.do_pooling = do_pooling
    self.model_init_kwargs = model_init_kwargs
    # The parent initializer runs last, after the wrapper-specific attributes are set.
    super().__init__(**kwargs)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
def __init__(self, config: TimmWrapperConfig):
    """Build the timm model described by `config`, without a classification head.

    Args:
        config: Supplies `architecture`, passed as the first argument to
            `timm.create_model`, and the optional `model_init_kwargs` that are
            forwarded to the same call as extra keyword arguments.
    """
    super().__init__(config)
    # `model_init_kwargs` may be None; fall back to an empty dict so `**` unpacking works.
    model_init_kwargs = config.model_init_kwargs or {}
    # using num_classes=0 to avoid creating classification head
    self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **model_init_kwargs)
    self.post_init()

@auto_docstring
Expand Down Expand Up @@ -233,7 +234,10 @@ def __init__(self, config: TimmWrapperConfig):
"or use `TimmWrapperModel` for feature extraction."
)

self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
model_init_kwargs = config.model_init_kwargs or {}
self.timm_model = timm.create_model(
config.architecture, pretrained=False, num_classes=config.num_labels, **model_init_kwargs
)
self.num_labels = config.num_labels
self.post_init()

Expand Down
11 changes: 11 additions & 0 deletions tests/models/timm_wrapper/test_modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ def test_timm_config_labels(self):
self.assertEqual(config.id2label, restored_config.id2label)
self.assertEqual(config.label2id, restored_config.label2id)

def test_model_init_kwargs(self):
    """`model_init_kwargs` set on the config must shape the timm model in both wrapper classes."""
    num_blocks = 3
    config = TimmWrapperConfig.from_pretrained(
        "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
        model_init_kwargs={"depth": num_blocks},
    )
    # Feature-extraction wrapper first, then the classification wrapper, sharing one config.
    for model_class in (TimmWrapperModel, TimmWrapperForImageClassification):
        wrapped = model_class(config)
        self.assertEqual(len(wrapped.timm_model.blocks), num_blocks)


# We will verify our results on an image of cute cats
def prepare_img():
Expand Down