Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 99 additions & 16 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,72 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
# if enable_gqa:
# raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

# tensors_to_save = ()

# Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
# if the input tensors are not contiguous.
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
# tensors_to_save += (query, key, value)

out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]

# tensors_to_save += (out)
# if _save_ctx:
# ctx.save_for_backward(*tensors_to_save)
# ctx.dropout_p = dropout_p
# ctx.is_causal = is_causal
# ctx.scale = scale
# ctx.attn_mask = attn_mask

out = out.transpose(1, 2).contiguous()
return out


# backward declaration:
# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
def _npu_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")


# ===== Context parallel =====


Expand Down Expand Up @@ -1722,22 +1788,39 @@ def _native_npu_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
if _parallel_config is None:
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
# input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
None,
scale,
None,
return_lse,
forward_op=_npu_attention_forward_op,
backward_op=_npu_attention_backward_op,
_parallel_config=_parallel_config,
)
return out


Expand Down