diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ab0d7102ee83..289c3e82955b 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -893,6 +893,72 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _npu_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    # if enable_gqa:
+    #     raise ValueError("`enable_gqa` is not yet supported for NPU fused attention.")
+    if return_lse:
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    # tensors_to_save = ()
+
+    # Contiguous is a must here! npu_fusion_attention is called with input_layout="BNSD",
+    # so the transposed (and therefore non-contiguous) tensors must be made contiguous first.
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    # tensors_to_save += (query, key, value)
+
+    out = npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(1),  # num_heads
+        input_layout="BNSD",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+    # tensors_to_save += (out,)
+    # if _save_ctx:
+    #     ctx.save_for_backward(*tensors_to_save)
+    #     ctx.dropout_p = dropout_p
+    #     ctx.is_causal = is_causal
+    #     ctx.scale = scale
+    #     ctx.attn_mask = attn_mask
+
+    out = out.transpose(1, 2).contiguous()
+    return out
+
+
+# Backward is not implemented for the NPU fused attention path yet; the stub below only
+# exists so the context-parallel template has a backward op to reference.
+def _npu_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for NPU fusion attention.")
+
+
 # ===== Context parallel =====
 
 
@@ -1722,22 +1788,39 @@ def _native_npu_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
-    out = npu_fusion_attention(
-        query,
-        key,
-        value,
-        query.size(1),  # num_heads
-        input_layout="BNSD",
-        pse=None,
-        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
-        pre_tockens=65536,
-        next_tockens=65536,
-        keep_prob=1.0 - dropout_p,
-        sync=False,
-        inner_precise=0,
-    )[0]
-    out = out.transpose(1, 2).contiguous()
+    if _parallel_config is None:
+        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+        out = npu_fusion_attention(
+            query,
+            key,
+            value,
+            query.size(1),  # num_heads
+            input_layout="BNSD",
+            # input_layout="BSND",
+            pse=None,
+            scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+            pre_tockens=65536,
+            next_tockens=65536,
+            keep_prob=1.0 - dropout_p,
+            sync=False,
+            inner_precise=0,
+        )[0]
+        out = out.transpose(1, 2).contiguous()
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            None,
+            scale,
+            None,
+            return_lse,
+            forward_op=_npu_attention_forward_op,
+            backward_op=_npu_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
     return out
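
For a quick sanity check of the new forward op, a minimal smoke test could look like the sketch below. It is illustrative only and assumes an Ascend NPU environment with torch_npu installed (so the "npu" device is registered); the module path, the helper name, and the (batch, seq_len, num_heads, head_dim) input layout all follow from the patch above. ctx is passed as None because the forward path in this patch never touches it.

import torch
import torch_npu  # noqa: F401  (assumed available; importing it registers the "npu" device)

from diffusers.models.attention_dispatch import _npu_attention_forward_op

# Inputs in (batch, seq_len, num_heads, head_dim) layout; the op transposes them to BNSD internally.
query = torch.randn(2, 1024, 24, 64, device="npu", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# ctx=None is fine here since the forward path does not save anything for backward yet.
out = _npu_attention_forward_op(None, query, key, value, dropout_p=0.0, scale=None)
print(out.shape)  # torch.Size([2, 1024, 24, 64]), same layout as the inputs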