
Commit 0d9d98f

Fix typos (#9739)
* update * update * update * update * update * update
1 parent 60ffa84 commit 0d9d98f

File tree

2 files changed: +95 additions, -2 deletions


docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

Lines changed: 20 additions & 0 deletions
@@ -313,6 +313,26 @@ image = pipe("a picture of a cat holding a sign that says hello world").images[0
 image.save('sd3-single-file-t5-fp8.png')
 ```
 
+### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model
+
+```python
+import torch
+from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
+
+transformer = SD3Transformer2DModel.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
+    torch_dtype=torch.bfloat16,
+)
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-large",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+image = pipe("a cat holding a sign that says hello world").images[0]
+image.save("sd35.png")
+```
+
 ## StableDiffusion3Pipeline
 
 [[autodoc]] StableDiffusion3Pipeline
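Note (not part of this commit): the documented snippet loads only the transformer from a single-file checkpoint and then plugs it into the rest of the pretrained pipeline. For comparison, a minimal sketch of loading everything through `from_pretrained`, assuming the diffusers-format `stabilityai/stable-diffusion-3.5-large` repository:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Load the whole SD3.5 Large pipeline from the diffusers-format repo,
# so no single-file conversion path is involved.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe("a cat holding a sign that says hello world").images[0]
image.save("sd35-from-pretrained.png")
```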

src/diffusers/loaders/single_file_utils.py

Lines changed: 75 additions & 2 deletions
@@ -75,6 +75,7 @@
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -113,6 +114,9 @@
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
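Note (illustration, not part of the commit): both helpers run after the `model.diffusion_model.` prefix has been stripped, so the block index is always the second dot-separated field of a key. A toy sketch of what they return; the key names and the 2432 width are illustrative stand-ins for an SD3.5-style checkpoint, and the import comes from a private module that may change:

```python
import torch

from diffusers.loaders.single_file_utils import get_attn2_layers, get_caption_projection_dim

toy_checkpoint = {
    "context_embedder.weight": torch.zeros(2432, 4096),
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
    "joint_blocks.1.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
    "joint_blocks.2.x_block.attn.qkv.weight": torch.zeros(3 * 2432, 2432),
}

# Only blocks 0 and 1 carry an "attn2." key, so they are the dual-attention layers.
print(get_attn2_layers(toy_checkpoint))  # (0, 1)

# The caption projection width is now read off context_embedder.weight
# instead of being hard-coded to 1536.
print(get_caption_projection_dim(toy_checkpoint))  # 2432
```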
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
