Skip to content

Commit f5929e0

Browse files
SunMarc, sayakpaul, yiyixuxu, a-r-r-o-w
authored
[FEAT] Model loading refactor (#10604)
* first draft model loading refactor
* revert name change
* fix bnb
* revert name
* fix dduf
* fix huanyan
* style
* Update src/diffusers/models/model_loading_utils.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* suggestions from reviews
* Update src/diffusers/models/modeling_utils.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* remove safetensors check
* fix default value
* more fix from suggestions
* revert logic for single file
* style
* typing + fix couple of issues
* improve speed
* Update src/diffusers/models/modeling_utils.py
  Co-authored-by: Aryan <aryan@huggingface.co>
* fp8 dtype
* add tests
* rename resolved_archive_file to resolved_model_file
* format
* map_location default cpu
* add utility function
* switch to smaller model + test inference
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* rm comment
* add log
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* add decorator
* cosine sim instead
* fix use_keep_in_fp32_modules
* comm

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 6fe05b9 commit f5929e0

File tree

12 files changed

+844
-515
lines changed

12 files changed

+844
-515
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353

5454
if is_accelerate_available():
55-
from accelerate import init_empty_weights
55+
from accelerate import dispatch_model, init_empty_weights
5656

5757
from ..models.modeling_utils import load_model_dict_into_meta
5858

@@ -366,19 +366,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
366366
keep_in_fp32_modules=keep_in_fp32_modules,
367367
)
368368

369+
device_map = None
369370
if is_accelerate_available():
370371
param_device = torch.device(device) if device else torch.device("cpu")
371-
named_buffers = model.named_buffers()
372-
unexpected_keys = load_model_dict_into_meta(
372+
empty_state_dict = model.state_dict()
373+
unexpected_keys = [
374+
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
375+
]
376+
device_map = {"": param_device}
377+
load_model_dict_into_meta(
373378
model,
374379
diffusers_format_checkpoint,
375380
dtype=torch_dtype,
376-
device=param_device,
381+
device_map=device_map,
377382
hf_quantizer=hf_quantizer,
378383
keep_in_fp32_modules=keep_in_fp32_modules,
379-
named_buffers=named_buffers,
384+
unexpected_keys=unexpected_keys,
380385
)
381-
382386
else:
383387
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
384388

@@ -400,4 +404,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
400404

401405
model.eval()
402406

407+
if device_map is not None:
408+
device_map_kwargs = {"device_map": device_map}
409+
dispatch_model(model, **device_map_kwargs)
410+
403411
return model

src/diffusers/loaders/single_file_utils.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
15931593
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
15941594

15951595
if is_accelerate_available():
1596-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1596+
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
15971597
else:
1598-
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
1599-
1600-
if model._keys_to_ignore_on_load_unexpected is not None:
1601-
for pat in model._keys_to_ignore_on_load_unexpected:
1602-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1603-
1604-
if len(unexpected_keys) > 0:
1605-
logger.warning(
1606-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1607-
)
1598+
model.load_state_dict(diffusers_format_checkpoint, strict=False)
16081599

16091600
if torch_dtype is not None:
16101601
model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
20612052
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
20622053

20632054
if is_accelerate_available():
2064-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2065-
if model._keys_to_ignore_on_load_unexpected is not None:
2066-
for pat in model._keys_to_ignore_on_load_unexpected:
2067-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
2068-
2069-
if len(unexpected_keys) > 0:
2070-
logger.warning(
2071-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
2072-
)
2073-
2055+
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
20742056
else:
20752057
model.load_state_dict(diffusers_format_checkpoint)
20762058

0 commit comments

Comments
 (0)