diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 2d1fdb5f7d83..3f85b9f7a290 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -82,6 +82,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
             # fp16 gelu not supported on mps before torch 2.0
             return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        elif gate.device.type == "npu":
+            return torch_npu.npu_fast_gelu(gate)
         return F.gelu(gate, approximate=self.approximate)
 
     def forward(self, hidden_states):
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ab0d7102ee83..95ed1f0940cf 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -893,6 +893,70 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _npu_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if return_lse:
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    # Transpose from BSND to BNSD and make contiguous, matching the non-parallel `_native_npu_attention` path.
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
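+    # Note on the fused call below (an assumption based on the arguments used here and in `_native_npu_attention`):
+    # inputs are passed in BNSD layout, `pre_tockens`/`next_tockens` bound the attention span on either side
+    # (65536 effectively means unrestricted), and `keep_prob` is the keep probability, i.e. `1 - dropout_p`.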
+    out = npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(1),  # num_heads
+        input_layout="BNSD",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+    out = out.transpose(1, 2).contiguous()
+    return out
+
+
+def _npu_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for NPU fused attention.")
+
+
 # ===== Context parallel =====
 
 
@@ -1722,22 +1786,39 @@ def _native_npu_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
-    out = npu_fusion_attention(
-        query,
-        key,
-        value,
-        query.size(1),  # num_heads
-        input_layout="BNSD",
-        pse=None,
-        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
-        pre_tockens=65536,
-        next_tockens=65536,
-        keep_prob=1.0 - dropout_p,
-        sync=False,
-        inner_precise=0,
-    )[0]
-    out = out.transpose(1, 2).contiguous()
+    if _parallel_config is None:
+        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+        out = npu_fusion_attention(
+            query,
+            key,
+            value,
+            query.size(1),  # num_heads
+            input_layout="BNSD",
+            pse=None,
+            scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+            pre_tockens=65536,
+            next_tockens=65536,
+            keep_prob=1.0 - dropout_p,
+            sync=False,
+            inner_precise=0,
+        )[0]
+        out = out.transpose(1, 2).contiguous()
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            None,
+            scale,
+            None,
+            return_lse,
+            forward_op=_npu_attention_forward_op,
+            backward_op=_npu_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
     return out
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 37fc412adcc3..2013f237e5c1 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -19,10 +19,13 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate
+from ..utils import deprecate, is_torch_npu_available
 from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention
 
+if is_torch_npu_available():
+    import torch_npu
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -1184,6 +1187,57 @@ def get_1d_rotary_pos_embed(
         return freqs_cis
 
 
+def npu_apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to the input tensor using the given frequency tensor, dispatching to the fused
+    `torch_npu.npu_rotary_mul` kernel when real-valued (cos, sin) frequencies are provided.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor of shape `[B, H, S, D]` (or `[B, S, H, D]` when `sequence_dim=1`).
+        freqs_cis (`Tuple[torch.Tensor]`):
+            Precomputed frequency tensors for the complex exponentials, `(cos, sin)` each of shape `[S, D]`.
+
+    Returns:
+        torch.Tensor: The input tensor with rotary embeddings applied.
+    """
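+    # `rotary_mode="interleave"` is expected to pair adjacent even/odd channels (the layout used by FLUX and
+    # CogVideoX), while `rotary_mode="half"` rotates the two halves of the head dimension, matching the eager
+    # `apply_rotary_emb` behaviour for `use_real_unbind_dim` of -1 and -2 respectively.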
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + rotary_mode = "interleave" + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + rotary_mode = "half" + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + out = torch_npu.npu_rotary_mul(x, cos, sin, rotary_mode=rotary_mode).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + def apply_rotary_emb( x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], @@ -1205,6 +1259,12 @@ def apply_rotary_emb( Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ + if is_torch_npu_available: + return npu_apply_rotary_emb(x=x, + freqs_cis=freqs_cis, + use_real=use_real, + use_real_unbind_dim=use_real_unbind_dim, + sequence_dim=sequence_dim) if use_real: cos, sin = freqs_cis # [S, D] if sequence_dim == 2: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index ae2a6298f5f7..0f10dae9d1a1 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -202,6 +202,79 @@ def forward( x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa +class AdaLayerNormZeroNpu(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True): + super().__init__() + + op_path="/root/lym/op_build/build/lib.linux-x86_64-cpython-311/ascend_ops.cpython-311-x86_64-linux-gnu.so" + torch.ops.load_library(op_path) + + if num_embeddings is not None: + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + hidden_dtype: Optional[torch.dtype] = None, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = torch.ops.ascend_ops.adalayernorm( + x=x, + scale=scale_msa[:, None], + shift=shift_msa[:, None], + epsilson=1e-6) + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + +class AdaLayerNormZeroSingleNpu(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. 
+ """ + + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + op_path="/root/lym/op_build/build/lib.linux-x86_64-cpython-311/ascend_ops.cpython-311-x86_64-linux-gnu.so" + torch.ops.load_library(op_path) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = torch.ops.ascend_ops.adalayernorm( + x=x, + scale= scale_msa[:, None], + shift=shift_msa[:, None], + epsilson=1e-6) + return x, gate_msa class LuminaRMSNormZero(nn.Module): """ @@ -351,6 +424,54 @@ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torc x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] return x +class AdaLayerNormContinuousNpu(nn.Module): + r""" + Adaptive normalization layer with a norm layer (layer_norm or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. + bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. 
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
+        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = torch.ops.ascend_ops.adalayernorm(
+            x=x, scale=scale[:, None, :], shift=shift[:, None, :], epsilson=self.eps
+        )
+        return x
+
 
 class LuminaLayerNormContinuous(nn.Module):
     def __init__(
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 16c526f437f2..4ea0767d054e 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -36,8 +36,15 @@
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-
+if is_torch_npu_available():
+    from ..normalization import AdaLayerNormContinuousNpu as AdaLayerNormContinuous
+    from ..normalization import AdaLayerNormZeroNpu as AdaLayerNormZero
+    from ..normalization import AdaLayerNormZeroSingleNpu as AdaLayerNormZeroSingle
+
+    # NOTE: locally built `ascend_ops` extension providing the fused `adalayernorm` kernel used below.
+    op_path = "/root/lym/op_build/build/lib.linux-x86_64-cpython-311/ascend_ops.cpython-311-x86_64-linux-gnu.so"
+    torch.ops.load_library(op_path)
+else:
+    from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -114,7 +121,7 @@ def __call__(
         if image_rotary_emb is not None:
             query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
-
+
         hidden_states = dispatch_attention_fn(
             query,
             key,
@@ -272,6 +279,27 @@ def __call__(
         return hidden_states
 
 
+class RMSNormNpu(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6, elementwise_affine=True):
+        """
+        RMSNorm (equivalent to T5LayerNorm) backed by the fused `torch_npu.npu_rms_norm` kernel.
+        `elementwise_affine` is accepted for API parity with `torch.nn.RMSNorm`; a weight is always created.
+        """
+        super().__init__()
+        import torch_npu
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.fused_norm = torch_npu.npu_rms_norm
+
+    def forward(self, hidden_states, if_fused=True):
+        if if_fused:
+            # the fused kernel returns a tuple; the normalized output is the first element
+            return self.fused_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
+        else:
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+            return self.weight * hidden_states.to(input_dtype)
+
+
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = FluxAttnProcessor
     _available_processors = [
@@ -309,9 +337,14 @@ def __init__(
         self.heads = out_dim // dim_head if out_dim is not None else heads
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
+
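+        # Prefer the fused torch_npu RMSNorm kernel when running on Ascend NPU; fall back to torch.nn.RMSNorm otherwise.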
+        if is_torch_npu_available():
+            RMSNorm = RMSNormNpu
+        else:
+            RMSNorm = torch.nn.RMSNorm
 
-        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
-        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -322,8 +355,8 @@ def __init__(
             self.to_out.append(torch.nn.Dropout(dropout))
 
         if added_kv_proj_dim is not None:
-            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
-            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.norm_added_q = RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = RMSNorm(dim_head, eps=eps)
             self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
@@ -360,7 +393,11 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
+        if is_torch_npu_available():
+            import torch_npu
+
+            self.act_mlp = torch_npu.npu_fast_gelu
+        else:
+            self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         self.attn = FluxAttention(
@@ -432,6 +469,7 @@ def __init__(
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
     def forward(
@@ -466,8 +504,15 @@ def forward(
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
-        norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if is_torch_npu_available():
+            norm_hidden_states = torch.ops.ascend_ops.adalayernorm(
+                x=hidden_states, scale=scale_mlp[:, None], shift=shift_mlp[:, None], epsilson=1e-6
+            )
+        else:
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
 
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -480,8 +525,15 @@ def forward(
         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output
 
-        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        if is_torch_npu_available():
+            norm_encoder_hidden_states = torch.ops.ascend_ops.adalayernorm(
+                x=encoder_hidden_states, scale=c_scale_mlp[:, None], shift=c_shift_mlp[:, None], epsilson=1e-6
+            )
+        else:
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
@@ -633,6 +685,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
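+        # Cache for the rotary position embedding so `pos_embed` only has to be evaluated once
+        # (this assumes the txt/img token layout does not change between forward calls).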
+        self.image_rotary_emb = None
 
     def forward(
         self,
@@ -717,11 +770,12 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        if self.image_rotary_emb is None:
+            if is_torch_npu_available():
+                freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+                self.image_rotary_emb = (
+                    freqs_cos.npu().to(hidden_states.dtype),
+                    freqs_sin.npu().to(hidden_states.dtype),
+                )
+            else:
+                self.image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.image_rotary_emb
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index 55c261ab2f29..ef1673c057ce 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -305,6 +305,7 @@ def from_pretrained(
             "cache_dir",
             "force_download",
             "local_files_only",
+            "local_dir",
             "proxies",
             "resume_download",
             "revision",
@@ -331,7 +332,6 @@ def from_pretrained(
             module_file=module_file,
             class_name=class_name,
             **hub_kwargs,
-            **kwargs,
         )
         expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
         block_kwargs = {
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index f67a0e211281..d605eac1f2b1 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -355,7 +355,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 68984da4dc40..9d0158c6b654 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -373,7 +373,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index bc281428e257..941b675099b9 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -326,7 +326,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
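+        # pooled_prompt_embeds is 2D (batch_size, dim): repeat along the last dim and reshape so each prompt's
+        # pooled embedding is duplicated num_images_per_prompt times, mirroring the 3D prompt_embeds handling above.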
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index 22a8dac238f5..f40dd52fc244 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -342,7 +342,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 3b7b26dc636c..660d9801df56 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -336,7 +336,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index db047f19924d..9b11bc8781e7 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -361,7 +361,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index c95fa530c8d7..b947cbff0914 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -367,7 +367,7 @@ def _get_clip_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
         pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index 627b1e0604dc..b2ef5a29e03e 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -254,6 +254,7 @@ def get_cached_module_file(
     token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
     local_files_only: bool = False,
+    local_dir: Optional[str] = None,
 ):
     """
     Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -332,6 +333,7 @@ def get_cached_module_file(
                 force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                local_dir=local_dir,
             )
             submodule = "git"
             module_file = pretrained_model_name_or_path + ".py"
@@ -355,6 +357,7 @@ def get_cached_module_file(
                 force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                local_dir=local_dir,
                 token=token,
             )
             submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -415,6 +418,7 @@ def get_cached_module_file(
                 token=token,
                 revision=revision,
                 local_files_only=local_files_only,
+                local_dir=local_dir,
             )
 
     return os.path.join(full_submodule, module_file)
@@ -431,7 +435,7 @@ def get_class_from_dynamic_module(
     token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
     local_files_only: bool = False,
-    **kwargs,
+    local_dir: Optional[str] = None,
 ):
     """
     Extracts a class from a module file, present in the local folder or repository of a model.
@@ -496,5 +500,6 @@ def get_class_from_dynamic_module(
         token=token,
         revision=revision,
         local_files_only=local_files_only,
+        local_dir=local_dir,
     )
     return get_class_in_module(class_name, final_module)