2 changes: 2 additions & 0 deletions src/diffusers/models/activations.py
@@ -82,6 +82,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
# fp16 gelu not supported on mps before torch 2.0
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
elif gate.device.type == "npu":
    # torch_npu provides a fused fast (approximate) GELU kernel for Ascend NPUs
    return torch_npu.npu_fast_gelu(gate)
return F.gelu(gate, approximate=self.approximate)

def forward(self, hidden_states):
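The new branch dispatches GELU to torch_npu's fused fast-GELU kernel on Ascend devices (this assumes torch_npu is imported at module level under an is_torch_npu_available() guard, as embeddings.py does in this PR). A minimal, hedged sanity check is sketched below; the fp16 tolerance is an assumption, since npu_fast_gelu computes an approximate GELU.

import torch
import torch.nn.functional as F


def check_npu_fast_gelu(shape=(4, 128), dtype=torch.float16, atol=1e-2):
    try:
        import torch_npu
    except ImportError:
        print("torch_npu not available; skipping the NPU check")
        return
    x = torch.randn(shape, dtype=dtype, device="npu")
    # npu_fast_gelu is an approximate GELU, so compare against the tanh
    # approximation computed in float32, with a loose tolerance.
    ref = F.gelu(x.float(), approximate="tanh").to(dtype)
    out = torch_npu.npu_fast_gelu(x)
    assert torch.allclose(out.float(), ref.float(), atol=atol)


check_npu_fast_gelu()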
113 changes: 97 additions & 16 deletions src/diffusers/models/attention_dispatch.py
@@ -893,6 +893,70 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if return_lse:
    raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

# `npu_fusion_attention` expects contiguous tensors in BNSD layout, so transpose from the
# dispatcher's BSND layout and make the tensors contiguous.
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()

out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]

# Context saving for the backward pass is omitted for now: the backward op below is not implemented.
out = out.transpose(1, 2).contiguous()
return out

def _npu_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
raise NotImplementedError("Backward pass is not implemented for NPU fused attention.")
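_npu_attention_forward_op receives tensors in the dispatcher's [B, S, N, D] layout, transposes them to BNSD for npu_fusion_attention, and transposes the result back; its backward is not implemented. A minimal CPU sketch of that layout contract follows; the plain softmax attention here is only an illustrative stand-in for the fused kernel, not the kernel itself.

import math
import torch


def reference_bsnd_attention(query, key, value, scale=None):
    # BSND -> BNSD, mirroring the transposes in _npu_attention_forward_op.
    q, k, v = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    scale = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale
    # Plain softmax attention as a stand-in for npu_fusion_attention.
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    out = attn @ v
    # BNSD -> BSND before returning, as the forward op does.
    return out.transpose(1, 2).contiguous()


q = k = v = torch.randn(2, 16, 8, 64)  # [batch, seq_len, num_heads, head_dim]
print(reference_bsnd_attention(q, k, v).shape)  # torch.Size([2, 16, 8, 64])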

# ===== Context parallel =====


@@ -1722,22 +1786,39 @@ def _native_npu_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
if _parallel_config is None:
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
else:
    out = _templated_context_parallel_attention(
        query,
        key,
        value,
        None,  # attn_mask
        dropout_p,
        None,  # is_causal
        scale,
        None,  # enable_gqa
        return_lse,
        forward_op=_npu_attention_forward_op,
        backward_op=_npu_attention_backward_op,
        _parallel_config=_parallel_config,
    )
return out


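A hedged parity check for the non-parallel path of _native_npu_attention is sketched below. It assumes an Ascend device with torch_npu installed, that the positional query/key/value plus dropout_p arguments shown in this diff are accepted as written, and that the default scaling matches PyTorch SDPA's 1/sqrt(head_dim); the fp16 tolerance is an assumption.

import torch
import torch.nn.functional as F


def check_native_npu_attention(atol=5e-3):
    try:
        import torch_npu  # noqa: F401
        from diffusers.models.attention_dispatch import _native_npu_attention
    except ImportError:
        print("torch_npu or diffusers not available; skipping")
        return
    # [batch, seq_len, num_heads, head_dim], the layout used by the dispatcher.
    q, k, v = (torch.randn(1, 32, 8, 64, dtype=torch.float16, device="npu") for _ in range(3))
    out = _native_npu_attention(q, k, v, dropout_p=0.0)
    # PyTorch SDPA reference, computed in float32 in BNSD layout.
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2).float(), k.transpose(1, 2).float(), v.transpose(1, 2).float()
    ).transpose(1, 2).to(q.dtype)
    assert torch.allclose(out.float(), ref.float(), atol=atol)


check_native_npu_attention()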
62 changes: 61 additions & 1 deletion src/diffusers/models/embeddings.py
@@ -19,10 +19,13 @@
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate
from ..utils import deprecate, is_torch_npu_available
from .activations import FP32SiLU, get_activation
from .attention_processor import Attention

if is_torch_npu_available():
    import torch_npu


def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -1184,6 +1187,57 @@ def get_1d_rotary_pos_embed(
return freqs_cis


def npu_apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
sequence_dim: int = 2,
) -> torch.Tensor:
    """
    NPU-accelerated variant of `apply_rotary_emb`. Applies rotary embeddings to the given query or key tensor `x`
    using the precomputed frequency tensor `freqs_cis`. When `use_real` is True, the rotation is delegated to the
    fused `torch_npu.npu_rotary_mul` kernel; otherwise the tensor is rotated in the complex domain as in the
    reference implementation.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape `[B, H, S, D]` (or `[B, S, H, D]` when
            `sequence_dim=1`).
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials, `([S, D], [S, D])`.

    Returns:
        `torch.Tensor`: The input tensor with rotary embeddings applied.
    """
if use_real:
cos, sin = freqs_cis # [S, D]
if sequence_dim == 2:
cos = cos[None, None, :, :]
sin = sin[None, None, :, :]
elif sequence_dim == 1:
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
else:
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

cos, sin = cos.to(x.device), sin.to(x.device)

if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
rotary_mode = "interleave"
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
rotary_mode = "half"
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = torch_npu.npu_rotary_mul(x, cos, sin, rotary_mode=rotary_mode).to(x.dtype)

return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

return x_out.type_as(x)

def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
@@ -1205,6 +1259,12 @@ def apply_rotary_emb(
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
if is_torch_npu_available():
    # Route to the fused torch_npu rotary kernel on Ascend devices.
    return npu_apply_rotary_emb(
        x=x,
        freqs_cis=freqs_cis,
        use_real=use_real,
        use_real_unbind_dim=use_real_unbind_dim,
        sequence_dim=sequence_dim,
    )
if use_real:
cos, sin = freqs_cis # [S, D]
if sequence_dim == 2:
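npu_apply_rotary_emb delegates the rotation to torch_npu.npu_rotary_mul in either "interleave" or "half" mode. The sketch below is a CPU reference for the interleaved case (use_real_unbind_dim=-1), mirroring the existing pure-PyTorch branch of apply_rotary_emb that the fused kernel is expected to match; the random cos/sin tables are illustrative stand-ins for get_1d_rotary_pos_embed outputs.

import torch


def rotary_interleave_reference(x, cos, sin):
    # x: [B, H, S, D]; cos/sin: [S, D], broadcast over batch and heads.
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    # Interleaved rotation: each pair (x0, x1) becomes (-x1, x0), as in
    # apply_rotary_emb with use_real_unbind_dim=-1.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)


x = torch.randn(2, 8, 16, 64)
cos, sin = torch.randn(16, 64), torch.randn(16, 64)
print(rotary_interleave_reference(x, cos, sin).shape)  # torch.Size([2, 8, 16, 64])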
121 changes: 121 additions & 0 deletions src/diffusers/models/normalization.py
@@ -202,6 +202,79 @@ def forward(
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa

class AdaLayerNormZeroNpu(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""

def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()

# NOTE: path to the custom Ascend op library is environment-specific.
op_path = "/root/lym/op_build/build/lib.linux-x86_64-cpython-311/ascend_ops.cpython-311-x86_64-linux-gnu.so"
torch.ops.load_library(op_path)

if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None

self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)

def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = torch.ops.ascend_ops.adalayernorm(
x=x,
scale=scale_msa[:, None],
shift=shift_msa[:, None],
epsilson=1e-6)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

class AdaLayerNormZeroSingleNpu(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""

def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()

op_path = "/root/lym/op_build/build/lib.linux-x86_64-cpython-311/ascend_ops.cpython-311-x86_64-linux-gnu.so"
torch.ops.load_library(op_path)

self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)

def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = torch.ops.ascend_ops.adalayernorm(
x=x,
scale=scale_msa[:, None],
shift=shift_msa[:, None],
epsilson=1e-6)
return x, gate_msa

class LuminaRMSNormZero(nn.Module):
"""
@@ -351,6 +424,54 @@ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x

class AdaLayerNormContinuousNpu(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""

def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()

op_path = "/root/lym/op_build/build/lib.linux-x86_64-cpython-311/ascend_ops.cpython-311-x86_64-linux-gnu.so"
torch.ops.load_library(op_path)

self.eps = eps
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)

def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = torch.ops.ascend_ops.adalayernorm(
x=x,
scale=scale[:, None, :],
shift=shift[:, None, :],
epsilson=self.eps)
return x

class LuminaLayerNormContinuous(nn.Module):
def __init__(
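All three new *Npu classes replace the LayerNorm-plus-modulation used by the existing AdaLayerNorm variants, x = norm(x) * (1 + scale) + shift, with a single fused torch.ops.ascend_ops.adalayernorm call. A plain PyTorch sketch of what that fused op is expected to compute is given below; the affine-free normalization is an assumption inferred from the surrounding classes, not a specification of the custom kernel.

import torch
import torch.nn.functional as F


def adalayernorm_reference(x, scale, shift, eps=1e-6):
    # x: [B, S, D]; scale/shift broadcastable to x (e.g. [B, 1, D]).
    # LayerNorm without learned affine parameters, followed by scale/shift
    # modulation, matching `self.norm(x) * (1 + scale) + shift` in the
    # non-NPU classes.
    normed = F.layer_norm(x, x.shape[-1:], weight=None, bias=None, eps=eps)
    return normed * (1 + scale) + shift


x = torch.randn(2, 16, 1152)
scale, shift = torch.randn(2, 2 * 1152).chunk(2, dim=1)
print(adalayernorm_reference(x, scale[:, None, :], shift[:, None, :]).shape)  # torch.Size([2, 16, 1152])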