[FEAT] Model loading refactor #10604


Merged: 43 commits from model-loading-refactor into main on Feb 19, 2025
Commits (43)
e54c540  first draft model loading refactor (SunMarc, Jan 17, 2025)
645abc9  revert name change (SunMarc, Jan 17, 2025)
bd81f50  fix bnb (SunMarc, Jan 17, 2025)
17c1be2  revert name (SunMarc, Jan 18, 2025)
72b6259  fix dduf (SunMarc, Jan 18, 2025)
b4e4f3b  fix huanyan (SunMarc, Jan 18, 2025)
5a00dc6  style (SunMarc, Jan 18, 2025)
3bcd6cc  Merge branch 'main' into model-loading-refactor (sayakpaul, Jan 20, 2025)
2f671af  Update src/diffusers/models/model_loading_utils.py (SunMarc, Jan 20, 2025)
7273a94  suggestions from reviews (SunMarc, Jan 20, 2025)
00f0bd1  Merge remote-tracking branch 'upstream/model-loading-refactor' into m… (SunMarc, Jan 20, 2025)
c5da192  Update src/diffusers/models/modeling_utils.py (SunMarc, Jan 21, 2025)
039eef5  remove safetensors check (SunMarc, Jan 21, 2025)
21f94a1  Merge remote-tracking branch 'upstream/model-loading-refactor' into m… (SunMarc, Jan 21, 2025)
337b2fc  fix default value (SunMarc, Jan 23, 2025)
aedf6af  Merge remote-tracking branch 'upstream/main' into model-loading-refactor (SunMarc, Jan 23, 2025)
0df7010  more fix from suggestions (SunMarc, Jan 23, 2025)
d3a7dc8  revert logic for single file (SunMarc, Jan 23, 2025)
fc4af16  style (SunMarc, Jan 23, 2025)
18d61bb  Merge remote-tracking branch 'upstream/main' into model-loading-refactor (SunMarc, Jan 23, 2025)
26228eb  Merge branch 'main' into model-loading-refactor (sayakpaul, Jan 27, 2025)
592c878  typing + fix couple of issues (SunMarc, Jan 27, 2025)
31c7d95  improve speed (SunMarc, Feb 4, 2025)
1634362  Merge remote-tracking branch 'upstream/model-loading-refactor' into m… (SunMarc, Feb 4, 2025)
37d59f6  Update src/diffusers/models/modeling_utils.py (SunMarc, Feb 14, 2025)
b48fedc  Merge remote-tracking branch 'upstream/main' into model-loading-refactor (SunMarc, Feb 14, 2025)
2753abe  fp8 dtype (SunMarc, Feb 14, 2025)
abd3a91  add tests (SunMarc, Feb 14, 2025)
d1c4a61  rename resolved_archive_file to resolved_model_file (SunMarc, Feb 14, 2025)
bcbd493  format (SunMarc, Feb 14, 2025)
0b1e9f5  map_location default cpu (SunMarc, Feb 17, 2025)
a078342  add utility function (SunMarc, Feb 17, 2025)
52c7104  switch to smaller model + test inference (SunMarc, Feb 17, 2025)
28211a4  Apply suggestions from code review (SunMarc, Feb 17, 2025)
e6a8093  rm comment (SunMarc, Feb 17, 2025)
69dda9a  add log (SunMarc, Feb 18, 2025)
68a6211  Apply suggestions from code review (SunMarc, Feb 18, 2025)
c2a72d3  add decorator (SunMarc, Feb 18, 2025)
712f6b8  Merge remote-tracking branch 'upstream/model-loading-refactor' into m… (SunMarc, Feb 18, 2025)
e00d1c4  cosine sim instead (SunMarc, Feb 18, 2025)
65aec7f  fix use_keep_in_fp32_modules (SunMarc, Feb 18, 2025)
f1138d3  comm (SunMarc, Feb 18, 2025)
176d30f  Merge branch 'main' into model-loading-refactor (sayakpaul, Feb 19, 2025)
Files changed
src/diffusers/loaders/single_file_model.py (20 changes: 14 additions, 6 deletions)
@@ -52,7 +52,7 @@
 
 
 if is_accelerate_available():
-    from accelerate import init_empty_weights
+    from accelerate import dispatch_model, init_empty_weights
 
     from ..models.modeling_utils import load_model_dict_into_meta
 
@@ -366,19 +366,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 keep_in_fp32_modules=keep_in_fp32_modules,
             )
 
+        device_map = None
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
-            named_buffers = model.named_buffers()
-            unexpected_keys = load_model_dict_into_meta(
+            empty_state_dict = model.state_dict()
+            unexpected_keys = [
+                param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+            ]
+            device_map = {"": param_device}
+            load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
                 dtype=torch_dtype,
-                device=param_device,
+                device_map=device_map,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
-                named_buffers=named_buffers,
+                unexpected_keys=unexpected_keys,
             )
-
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
@@ -400,4 +404,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
 
         model.eval()
 
+        if device_map is not None:
+            device_map_kwargs = {"device_map": device_map}
+            dispatch_model(model, **device_map_kwargs)
+
         return model
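
Below is a hedged usage sketch of the code path this file refactors; the checkpoint path is a placeholder and AutoencoderKL stands in for any model class that supports single-file loading. The change above swaps the explicit device= / named_buffers= plumbing for a {"": param_device} device_map (the whole model on a single device) plus a final dispatch_model call, while the user-facing call stays the same:

```python
# Hypothetical usage sketch of the refactored from_single_file() path shown above.
# The checkpoint path is a placeholder; substitute any single-file checkpoint and a
# model class that supports from_single_file().
import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_single_file(
    "./vae.safetensors",        # placeholder: a local single-file checkpoint
    torch_dtype=torch.float16,  # dtype applied while materializing the meta-device weights
)
# With accelerate installed, the weights are loaded via load_model_dict_into_meta using a
# {"": device} device_map, and the model is then dispatched with accelerate's dispatch_model.
```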
src/diffusers/loaders/single_file_utils.py (24 changes: 3 additions, 21 deletions)
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

[Inline review comment (Member): @DN6 would you like to run the slow single-file tests to ensure we're not breaking anything here (if not already)?]


     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
-        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-    if model._keys_to_ignore_on_load_unexpected is not None:
-        for pat in model._keys_to_ignore_on_load_unexpected:
-            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-    if len(unexpected_keys) > 0:
-        logger.warning(
-            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-        )
+        model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
     diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
 
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
         model.load_state_dict(diffusers_format_checkpoint)
 
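
The unexpected-key reporting removed from these two helpers relates to the pattern single_file_model.py now uses instead: unexpected keys are computed up front as a set difference against the model's empty state dict and passed into load_model_dict_into_meta. A minimal standalone sketch of that set-difference pattern, with model and checkpoint as stand-ins for the objects in the diffs above:

```python
# Minimal sketch of the unexpected-key computation introduced in single_file_model.py.
# `model` and `checkpoint` are stand-ins; a real call would use the diffusers model
# instantiated on the meta device and the converted single-file state dict.
import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(4, 4)  # parameters allocated on the meta device, no real memory

checkpoint = {
    "weight": torch.zeros(4, 4),
    "bias": torch.zeros(4),
    "extra.scale": torch.ones(1),  # a key the model does not expect
}

empty_state_dict = model.state_dict()
unexpected_keys = [name for name in checkpoint if name not in empty_state_dict]
print(unexpected_keys)  # ['extra.scale']
```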