Skip to content

[xnn update prep] deprecate sdpa #11506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/xnnpack/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
op_quant_dequant,
op_relu,
op_rsqrt,
op_sdpa,
op_sigmoid,
op_skip_ops,
op_slice_copy,
Expand Down
111 changes: 0 additions & 111 deletions backends/xnnpack/operators/op_sdpa.py

This file was deleted.

2 changes: 0 additions & 2 deletions backends/xnnpack/partition/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
QuantizedPerTensorConfig,
ReciprocalSquareRootConfig,
ReLUConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
SliceCopyConfig,
SoftmaxConfig,
Expand Down Expand Up @@ -99,7 +98,6 @@
PreluConfig,
ReciprocalSquareRootConfig,
ReLUConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
SliceCopyConfig,
SoftmaxConfig,
Expand Down
30 changes: 0 additions & 30 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,33 +527,3 @@ class BMMConfig(GenericNodePartitionerConfig):

def supported_precision_types(self) -> List[ConfigPrecisionType]:
    # BMM is only partitioned to XNNPACK for FP32 computation.
    return [ConfigPrecisionType.FP32]


class SDPAConfig(GenericNodePartitionerConfig):
    """Partitioner config for aten scaled_dot_product_attention (SDPA).

    Decides whether an SDPA node can be lowered to the XNNPACK backend.
    Only FP32 is supported, and the attention mask must be rank 2.
    """

    target_name = "scaled_dot_product_attention.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Requires Mask to have Rank 2
        """
        # Shared constraints (precision, etc.) checked by the base class first.
        if not self.check_common_constraints(node, ep):
            return False

        # Need at least query, key, value, and an explicit mask input;
        # SDPA without a mask is not partitioned.
        if len(node.all_input_nodes) < 4:
            return False
        mask_node = node.all_input_nodes[3]
        # meta["val"] holds the FakeTensor recorded at export time;
        # its dim() gives the mask's rank.
        mask_rank = mask_node.meta["val"].dim()
        if mask_rank != 2:
            # `why` records the rejection reason for partitioner debugging.
            why(
                node,
                reason=f"mask must have rank 2, got mask of rank {mask_rank}",
            )
            return False

        return True

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        # The aten overload this config targets, used to preserve the op
        # from decomposition during export.
        return torch.ops.aten.scaled_dot_product_attention.default

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        # SDPA is only lowered for FP32.
        return [ConfigPrecisionType.FP32]
37 changes: 0 additions & 37 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1961,42 +1961,6 @@ Error defineStaticSliceNode(
return Error::Ok;
}

/*
Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
/*
Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value.

@param subgraph_ptr  XNNPACK subgraph being built.
@param remapped_ids  Map from serialized tensor ids to subgraph tensor ids.
@param node          Serialized node; must be an XNNScaledDotProductAttention union.
@param graph         Full serialized graph (unused here).
@return Error::Ok on success, Internal if XNNPACK rejects the node.
*/
Error defineScaledDotProductAttentionNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);

  auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();

  // No logits cap is applied (cap_type none, cap_value unused).
  xnn_status status = xnn_define_scaled_dot_product_attention(
      subgraph_ptr,
      xnn_attention_logits_cap_type_none, // cap_type
      nullptr, // cap_value - not used
      remapped_ids.at(graph_node->query_id()),
      remapped_ids.at(graph_node->key_id()),
      remapped_ids.at(graph_node->value_id()),
      remapped_ids.at(graph_node->scale_id()),
      remapped_ids.at(graph_node->mask_id()),
      remapped_ids.at(graph_node->output_id()),
      graph_node->flags());

  // Surface XNNPACK failures with the node's debug handle for triage.
  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create SDPA node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));

  return Error::Ok;
}

/*
Defines batch matrix multiply node into the subgraph,
using the remapped ids to map the serialized ids,
Expand Down Expand Up @@ -2097,7 +2061,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(Concatenate4)
_DEFINE(Concatenate5)
_DEFINE(StaticSlice)
_DEFINE(ScaledDotProductAttention)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the schema also be updated to mark XNNScaledDotProductAttention as deprecated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XNNPACK removed this operator from their codebase, so we need to delete it before the next update. I can mark the operator in the schema as deprecated, though.

_DEFINE(BatchMatrixMultiply)
case fb_xnnpack::XNodeUnion::NONE:
default: // Adding here as a catch all, just in case
Expand Down
130 changes: 0 additions & 130 deletions backends/xnnpack/test/ops/test_sdpa.py

This file was deleted.

Loading