@@ -893,6 +893,72 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")


+def _npu_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
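+    # Forward op for Ascend NPU fused attention via torch_npu's `npu_fusion_attention`.
+    # It is used as the `forward_op` of the templated context-parallel attention path below.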
+    # if enable_gqa:
+    #     raise ValueError("`enable_gqa` is not yet supported for NPU fusion attention.")
+    if return_lse:
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    # tensors_to_save = ()
+
+    # Contiguous tensors are a must here! Passing non-contiguous inputs to the fused
+    # attention kernel can produce incorrect results.
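+    # `query`/`key`/`value` arrive as (batch, seq_len, heads, head_dim); transpose to
+    # (batch, heads, seq_len, head_dim) to match the "BNSD" input layout used below.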
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    # tensors_to_save += (query, key, value)
+
+    out = npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(1),  # num_heads
+        input_layout="BNSD",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+    # tensors_to_save += (out,)
+    # if _save_ctx:
+    #     ctx.save_for_backward(*tensors_to_save)
+    #     ctx.dropout_p = dropout_p
+    #     ctx.is_causal = is_causal
+    #     ctx.scale = scale
+    #     ctx.attn_mask = attn_mask
+
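+    # Transpose the output back to (batch, seq_len, heads, head_dim) for the caller.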
+    out = out.transpose(1, 2).contiguous()
+    return out
+
+
+def _npu_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for NPU fusion attention.")
+
+
 # ===== Context parallel =====


@@ -1722,22 +1788,39 @@ def _native_npu_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
-    out = npu_fusion_attention(
-        query,
-        key,
-        value,
-        query.size(1),  # num_heads
-        input_layout="BNSD",
-        pse=None,
-        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
-        pre_tockens=65536,
-        next_tockens=65536,
-        keep_prob=1.0 - dropout_p,
-        sync=False,
-        inner_precise=0,
-    )[0]
-    out = out.transpose(1, 2).contiguous()
+    if _parallel_config is None:
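+        # No parallel config: run the fused NPU kernel directly on the full sequence.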
+        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+        out = npu_fusion_attention(
+            query,
+            key,
+            value,
+            query.size(1),  # num_heads
+            input_layout="BNSD",
+            # input_layout="BSND",
+            pse=None,
+            scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+            pre_tockens=65536,
+            next_tockens=65536,
+            keep_prob=1.0 - dropout_p,
+            sync=False,
+            inner_precise=0,
+        )[0]
+        out = out.transpose(1, 2).contiguous()
+    else:
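+        # With a parallel config, delegate to the templated context-parallel attention
+        # helper, which uses the NPU forward/backward ops defined above.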
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            None,
+            scale,
+            None,
+            return_lse,
+            forward_op=_npu_attention_forward_op,
+            backward_op=_npu_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
     return out

