Propose to refactor output normalization in several transformers #11850

Draft: wants to merge 8 commits into base: main
16 changes: 12 additions & 4 deletions scripts/convert_wan_to_diffusers.py
@@ -25,7 +25,7 @@
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.modulation": "norm_out.linear.weight",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
@@ -67,7 +67,7 @@
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.modulation": "norm_out.linear.weight",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
@@ -105,8 +105,12 @@
"after_proj": "proj_out",
}

-TRANSFORMER_SPECIAL_KEYS_REMAP = {}
-VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0]))
+}
+VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0]))
+}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -308,6 +312,10 @@ def convert_transformer(model_type: str):
continue
handler_fn_inplace(key, original_state_dict)

+for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
+    if special_key not in original_state_dict:
+        handler_fn_inplace(special_key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer

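Note on the new SPECIAL_KEYS_REMAP entries: the handlers call dict.setdefault and the added loop in convert_transformer only invokes them for keys that are absent, so checkpoints that already ship a norm_out.linear.bias are left alone, while older checkpoints get an all-zeros bias sized from norm_out.linear.weight. A minimal sketch of that pattern on a toy state dict (the names mirror the script; the shapes are made up for illustration):

# Illustrative only: toy shapes, not taken from a real Wan checkpoint.
import torch

state_dict = {"norm_out.linear.weight": torch.randn(32, 16)}

special_keys_remap = {
    "norm_out.linear.bias": lambda key, sd: sd.setdefault(
        key, torch.zeros(sd["norm_out.linear.weight"].shape[0])
    )
}

# Mirrors the loop added to convert_transformer(): handlers run only for missing keys.
for special_key, handler in special_keys_remap.items():
    if special_key not in state_dict:
        handler(special_key, state_dict)

assert state_dict["norm_out.linear.bias"].shape == (32,)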
17 changes: 10 additions & 7 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -23,11 +23,12 @@
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle


class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
+_no_split_modules = ["norm_out"]

"""
A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
@@ -149,8 +150,13 @@ def __init__(

# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
-self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+self.norm_out = AdaLayerNorm(
+    embedding_dim=inner_dim,
+    output_dim=2 * inner_dim,
+    norm_elementwise_affine=False,
+    norm_eps=1e-6,
+    chunk_dim=1,
+)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

# 5. Latte other blocks.
@@ -305,10 +311,7 @@ def forward(
embedded_timestep = embedded_timestep.repeat_interleave(
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
).view(-1, embedded_timestep.shape[-1])
-shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-hidden_states = self.norm_out(hidden_states)
-# Modulation
-hidden_states = hidden_states * (1 + scale) + shift
+hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
hidden_states = self.proj_out(hidden_states)

# unpatchify
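For context on the model-side change repeated across the files in this PR: the old code kept a learned scale_shift_table next to a plain LayerNorm and did the shift/scale modulation inline; the new code delegates that to diffusers' AdaLayerNorm, which (in its current implementation) runs the timestep embedding through SiLU and a Linear before chunking it into shift and scale. So the two paths agree on shapes, but they are not numerically identical for arbitrary weights. A rough sketch of the shape flow under those assumptions (dimensions are illustrative):

# Rough sketch, not part of the PR: compares the shape flow of the old inline
# modulation with AdaLayerNorm(chunk_dim=1). Weights are random, so the two
# outputs are not expected to match numerically.
import torch
import torch.nn as nn
from diffusers.models.normalization import AdaLayerNorm

batch, seq_len, dim = 2, 16, 64
hidden_states = torch.randn(batch, seq_len, dim)
embedded_timestep = torch.randn(batch, dim)

# Old path: a learned (2, dim) table added to the timestep embedding, chunked
# into shift/scale around an affine-free LayerNorm.
scale_shift_table = torch.randn(2, dim) / dim**0.5
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
old_out = norm(hidden_states) * (1 + scale) + shift

# New path: AdaLayerNorm owns both the LayerNorm and the shift/scale
# projection, and is fed the timestep embedding via temb=.
norm_out = AdaLayerNorm(
    embedding_dim=dim,
    output_dim=2 * dim,
    norm_elementwise_affine=False,
    norm_eps=1e-6,
    chunk_dim=1,
)
new_out = norm_out(hidden_states, temb=embedded_timestep)

assert old_out.shape == new_out.shape == (batch, seq_len, dim)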
20 changes: 10 additions & 10 deletions src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -23,7 +23,7 @@
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]

@register_to_config
@@ -171,8 +171,13 @@ def __init__(
)

# 3. Output blocks.
-self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
-self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+self.norm_out = AdaLayerNorm(
+    embedding_dim=self.inner_dim,
+    output_dim=2 * self.inner_dim,
+    norm_elementwise_affine=False,
+    norm_eps=1e-6,
+    chunk_dim=1,
+)
self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

self.adaln_single = AdaLayerNormSingle(
@@ -406,12 +411,7 @@ def forward(
)

# 3. Output
-shift, scale = (
-    self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-).chunk(2, dim=1)
-hidden_states = self.norm_out(hidden_states)
-# Modulation
-hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)

18 changes: 10 additions & 8 deletions src/diffusers/models/transformers/transformer_allegro.py
@@ -28,7 +28,7 @@
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle


logger = logging.get_logger(__name__)
@@ -175,6 +175,7 @@ def forward(

class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
+_no_split_modules = ["norm_out"]

"""
A 3D Transformer model for video-like data.
@@ -292,8 +293,13 @@ def __init__(
)

# 3. Output projection & norm
-self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
-self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+self.norm_out = AdaLayerNorm(
+    embedding_dim=self.inner_dim,
+    output_dim=2 * self.inner_dim,
+    norm_elementwise_affine=False,
+    norm_eps=1e-6,
+    chunk_dim=1,
+)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

# 4. Timestep embeddings
@@ -393,11 +399,7 @@ def forward(
)

# 4. Output normalization & projection
-shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-hidden_states = self.norm_out(hidden_states)
-
-# Modulation
-hidden_states = hidden_states * (1 + scale) + shift
+hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)

18 changes: 10 additions & 8 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -30,7 +30,7 @@
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
+_no_split_modules = ["norm_out"]
_repeated_blocks = ["LTXVideoTransformerBlock"]

@register_to_config
@@ -356,7 +357,6 @@ def __init__(

self.proj_in = nn.Linear(in_channels, inner_dim)

-self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
@@ -389,7 +389,13 @@ def __init__(
]
)

-self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+self.norm_out = AdaLayerNorm(
+    embedding_dim=inner_dim,
+    output_dim=2 * inner_dim,
+    norm_elementwise_affine=False,
+    norm_eps=1e-6,
+    chunk_dim=1,
+)
self.proj_out = nn.Linear(inner_dim, out_channels)

self.gradient_checkpointing = False
@@ -464,11 +470,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
)

-scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
-shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-
-hidden_states = self.norm_out(hidden_states)
-hidden_states = hidden_states * (1 + scale) + shift
+hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1))
output = self.proj_out(hidden_states)

if USE_PEFT_BACKEND:
24 changes: 10 additions & 14 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -28,7 +28,7 @@
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import FP32LayerNorm
+from ..normalization import AdaLayerNorm, FP32LayerNorm


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -370,7 +370,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock"]
_no_split_modules = ["WanTransformerBlock", "norm_out"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
@@ -426,9 +426,14 @@ def __init__(
)

# 4. Output norm & projection
-self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+self.norm_out = AdaLayerNorm(
+    embedding_dim=inner_dim,
+    output_dim=2 * inner_dim,
+    norm_elementwise_affine=False,
+    norm_eps=eps,
+    chunk_dim=1,
+)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

self.gradient_checkpointing = False

@@ -486,16 +491,7 @@ def forward(
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

# 5. Output norm, projection & unpatchify
-shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-
-# Move the shift and scale tensors to the same device as hidden_states.
-# When using multi-GPU inference via accelerate these will be on the
-# first device rather than the last device, which hidden_states ends up
-# on.
-shift = shift.to(hidden_states.device)
-scale = scale.to(hidden_states.device)
-
-hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
24 changes: 10 additions & 14 deletions src/diffusers/models/transformers/transformer_wan_vace.py
@@ -26,7 +26,7 @@
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import FP32LayerNorm
+from ..normalization import AdaLayerNorm, FP32LayerNorm
from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock


@@ -179,7 +179,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]

@@ -259,9 +259,14 @@ def __init__(
)

# 4. Output norm & projection
-self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+self.norm_out = AdaLayerNorm(
+    embedding_dim=inner_dim,
+    output_dim=2 * inner_dim,
+    norm_elementwise_affine=False,
+    norm_eps=eps,
+    chunk_dim=1,
+)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

self.gradient_checkpointing = False

@@ -365,16 +370,7 @@ def forward(
hidden_states = hidden_states + control_hint * scale

# 6. Output norm, projection & unpatchify
-shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-
-# Move the shift and scale tensors to the same device as hidden_states.
-# When using multi-GPU inference via accelerate these will be on the
-# first device rather than the last device, which hidden_states ends up
-# on.
-shift = shift.to(hidden_states.device)
-scale = scale.to(hidden_states.device)
-
-hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(