diff --git a/timm/models/_builder.py b/timm/models/_builder.py index d51db363a3..8a66f69f42 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -1,6 +1,7 @@ import dataclasses import logging import os +import inspect from copy import deepcopy from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union @@ -303,6 +304,19 @@ def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None: for n in names: kwargs.pop(n, None) +def _ignore_kwargs(func, kwargs): + """ Filter kwargs to those that func accepts. + """ + sig = inspect.signature(func) + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): + return kwargs + filter_keys = [p.name for p in sig.parameters.values() if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in filter_keys} + ignored_keys = set(kwargs.keys()) - set(filtered_kwargs.keys()) + if ignored_keys: + _logger.warning( + f'Ignored attempt to pass arguments ({", ".join(ignored_keys)}) to function {func}.') + return filtered_kwargs def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None: """ Update the default_cfg and kwargs before passing to model @@ -441,6 +455,7 @@ def build_model_with_cfg( feature_cfg['feature_cls'] = kwargs.pop('feature_cls') # Instantiate the model + kwargs = _ignore_kwargs(model_cls.__init__, kwargs) if model_cfg is None: model = model_cls(**kwargs) else: