Commit 59ec7b4

Merge pull request Eco-Sphere#1 from TmacAaron/dev
Flux & Wan-T2V Ascend Dev

2 parents 9f3c0fd + 6f7583b

File tree: 7 files changed (+374, -29 lines)

src/diffusers/models/attention_dispatch.py (99 additions, 16 deletions)
@@ -893,6 +893,72 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _npu_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if return_lse:
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    # `npu_fusion_attention` is called with input_layout="BNSD", so transpose the
+    # BSND inputs and make them contiguous before dispatching to the fused kernel.
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+
+    out = npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(1),  # num_heads
+        input_layout="BNSD",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+    # Context saving for a future backward pass is stubbed out for now:
+    # if _save_ctx:
+    #     ctx.save_for_backward(query, key, value, out)
+    #     ctx.dropout_p = dropout_p
+    #     ctx.is_causal = is_causal
+    #     ctx.scale = scale
+    #     ctx.attn_mask = attn_mask
+
+    out = out.transpose(1, 2).contiguous()
+    return out
+
+
+# Reference backward signature (from the cuDNN backend):
+# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+def _npu_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for NPU fusion attention.")
+
+
 # ===== Context parallel =====
 
 
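The new forward op receives tensors in BSND layout (batch, sequence, heads, head_dim) and converts them to the BNSD layout that npu_fusion_attention expects, falling back to the standard 1/sqrt(head_dim) softmax scale when none is given. A minimal standalone sketch of that layout and scale handling (shapes are illustrative, not part of the commit):

import math
import torch

# Hypothetical BSND input: batch=2, seq_len=16, num_heads=8, head_dim=64.
query = torch.randn(2, 16, 8, 64)

# Transpose to BNSD and make contiguous, as _npu_attention_forward_op does
# before calling npu_fusion_attention with input_layout="BNSD".
q_bnsd = query.transpose(1, 2).contiguous()  # shape (2, 8, 16, 64)

# Default softmax scale used when `scale` is None: 1/sqrt(head_dim) = 0.125 here.
scale = 1.0 / math.sqrt(q_bnsd.shape[-1])
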
@@ -1722,22 +1788,39 @@ def _native_npu_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
-    out = npu_fusion_attention(
-        query,
-        key,
-        value,
-        query.size(1),  # num_heads
-        input_layout="BNSD",
-        pse=None,
-        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
-        pre_tockens=65536,
-        next_tockens=65536,
-        keep_prob=1.0 - dropout_p,
-        sync=False,
-        inner_precise=0,
-    )[0]
-    out = out.transpose(1, 2).contiguous()
+    if _parallel_config is None:
+        # Single-device path: call the Ascend fused kernel directly in BNSD layout.
+        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+        out = npu_fusion_attention(
+            query,
+            key,
+            value,
+            query.size(1),  # num_heads
+            input_layout="BNSD",
+            pse=None,
+            scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+            pre_tockens=65536,
+            next_tockens=65536,
+            keep_prob=1.0 - dropout_p,
+            sync=False,
+            inner_precise=0,
+        )[0]
+        out = out.transpose(1, 2).contiguous()
+    else:
+        # Context-parallel path: shard the attention computation across ranks via the
+        # templated helper, using the NPU forward/backward ops defined above.
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            None,
+            scale,
+            None,
+            return_lse,
+            forward_op=_npu_attention_forward_op,
+            backward_op=_npu_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
     return out
 
 
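The commit title mentions Flux and Wan-T2V on Ascend; once this backend is registered with the dispatcher, selecting it from user code would look roughly like the sketch below. This is a hedged usage sketch: the attention_backend context manager comes from diffusers' attention dispatcher, but the exact backend name string ("_native_npu" here), the checkpoint, and the prompt are assumptions, not part of this diff.

import torch
from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import attention_backend

# Assumes an Ascend NPU with torch_npu installed; "npu" is the device string
# that torch_npu registers with PyTorch.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("npu")

# Route attention calls through the NPU fused-attention backend added in this commit
# (check AttentionBackendName in attention_dispatch.py for the exact registered name).
with attention_backend("_native_npu"):
    image = pipe("a watercolor painting of a lighthouse at dusk").images[0]
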