From d70c4811791cbaaaed74530e05cb0479426f0e14 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 30 Mar 2025 17:05:25 +0800 Subject: [PATCH 1/2] override device kwargs of base nn classes --- timm/models/_builder.py | 49 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 482d370a94..cf80b455b1 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -1,10 +1,12 @@ +from contextlib import contextmanager, nullcontext import dataclasses import logging import os from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union +import torch from torch import nn as nn from torch.hub import load_state_dict_from_url @@ -360,6 +362,27 @@ def resolve_pretrained_cfg( return pretrained_cfg +@contextmanager +def make_meta_init(*classes): + def create_new_init(cls): + old_init = cls.__init__ + def new_init(self, *args, **kwargs): + kwargs.update(device="meta") + old_init(self, *args, **kwargs) + return new_init + + original_dict = dict() + for cls in classes: + original_dict[cls] = cls.__init__ + cls.__init__ = create_new_init(cls) + + yield + + # restore original __init__() + for cls, old_init in original_dict.items(): + cls.__init__ = old_init + + def build_model_with_cfg( model_cls: Callable, variant: str, @@ -419,11 +442,27 @@ def build_model_with_cfg( if 'feature_cls' in kwargs: feature_cfg['feature_cls'] = kwargs.pop('feature_cls') + # use meta-device init to speed up loading pretrained weights. + # when num_classes is changed, we can't use meta device init since we need + # the original __init__() to initialize head from scratch. + num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"]) + use_meta_init = ( + pretrained + and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"]) + ) + # Instantiate the model - if model_cfg is None: - model = model_cls(**kwargs) - else: - model = model_cls(cfg=model_cfg, **kwargs) + base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm] + with make_meta_init(*base_classes) if use_meta_init else nullcontext(): + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) + + # convert meta-device tensors to concrete tensors + device = kwargs.get("device", torch.get_default_device()) + model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t)) + model.pretrained_cfg = pretrained_cfg model.default_cfg = model.pretrained_cfg # alias for backwards compat From c3445e9d758ad3c5f3ec66c6b722cf7b78c9d325 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 30 Mar 2025 17:16:25 +0800 Subject: [PATCH 2/2] support pytorch<2 --- timm/models/_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index cf80b455b1..1b2f653396 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -460,7 +460,8 @@ def build_model_with_cfg( model = model_cls(cfg=model_cfg, **kwargs) # convert meta-device tensors to concrete tensors - device = kwargs.get("device", torch.get_default_device()) + default_device = torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu" + device = kwargs.get("device", default_device) model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t)) model.pretrained_cfg = pretrained_cfg