Skip to content

Commit b95e335

Browse files
committed
Use the meta device when instantiating pretrained models (materialized on CPU afterwards via to_empty)
1 parent 82e8677 commit b95e335

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

timm/models/_builder.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import logging
33
import os
44
from copy import deepcopy
5-
from typing import Any, Callable, Dict, List, Optional, Tuple
5+
from typing import Any, Callable, Dict, Optional, Tuple
6+
from contextlib import nullcontext
67

8+
import torch
79
from torch import nn as nn
810
from torch.hub import load_state_dict_from_url
911

@@ -411,10 +413,13 @@ def build_model_with_cfg(
411413
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
412414

413415
# Instantiate the model
414-
if model_cfg is None:
415-
model = model_cls(**kwargs)
416-
else:
417-
model = model_cls(cfg=model_cfg, **kwargs)
416+
with torch.device("meta") if pretrained else nullcontext():
417+
if model_cfg is None:
418+
model = model_cls(**kwargs)
419+
else:
420+
model = model_cls(cfg=model_cfg, **kwargs)
421+
if pretrained:
422+
model.to_empty(device="cpu")
418423
model.pretrained_cfg = pretrained_cfg
419424
model.default_cfg = model.pretrained_cfg # alias for backwards compat
420425

timm/models/vision_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,8 @@ def __init__(
539539
self.patch_drop = nn.Identity()
540540
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
541541

542-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
542+
with torch.device("cpu"):
543+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
543544
self.blocks = nn.Sequential(*[
544545
block_fn(
545546
dim=embed_dim,

0 commit comments

Comments
 (0)