From ad82d03a2832bc8d8d61fe80a1c14980297935d1 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Mon, 9 Jun 2025 15:10:18 -0700
Subject: [PATCH 1/3] Update

[ghstack-poisoned]
---
 backends/xnnpack/runtime/XNNCompiler.cpp | 8 ++++++++
 backends/xnnpack/serialization/runtime_schema.fbs | 9 +++++++++
 backends/xnnpack/serialization/schema.fbs | 9 +++++++++
 backends/xnnpack/serialization/xnnpack_graph_schema.py | 4 ++++
 4 files changed, 30 insertions(+)

diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 56d0508bef0..312cbc17b95 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -121,6 +121,14 @@ xnn_datatype getDataType(const DataType& data_type) {
       return xnn_datatype::xnn_datatype_qdint8;
     case DataType::xnn_datatype_qbint4:
       return xnn_datatype::xnn_datatype_qbint4;
+    case DataType::xnn_datatype_qpint8:
+      return xnn_datatype::xnn_datatype_qpint8;
+    case DataType::xnn_datatype_int32:
+      return xnn_datatype::xnn_datatype_int32;
+    case DataType::xnn_datatype_pfp32:
+      return xnn_datatype::xnn_datatype_pfp32;
+    case DataType::xnn_datatype_bf16:
+      return xnn_datatype::xnn_datatype_bf16;
     default:
       return xnn_datatype::xnn_datatype_invalid;
   }
 }
diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs
index d76c3c0807e..a0d44327912 100644
--- a/backends/xnnpack/serialization/runtime_schema.fbs
+++ b/backends/xnnpack/serialization/runtime_schema.fbs
@@ -29,6 +29,15 @@ enum XNNDatatype : short {
   xnn_datatype_qdint8 = 9,
   /// Quantized 4-bit signed integer with shared blockwise quantization parameters.
   xnn_datatype_qbint4 = 10,
+  /// Dynamically quantized 8-bit signed integers packed with their per-row
+  /// quantization parameters.
+  xnn_datatype_qpint8 = 11,
+  /// 32-bit signed integers.
+  xnn_datatype_int32 = 12,
+  /// IEEE754 single-precision packed floating-point.
+  xnn_datatype_pfp32 = 13,
+  /// BFloat16, i.e. the upper 16 bits of a float32.
+  xnn_datatype_bf16 = 14,
 }

 // type of quantization
diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs
index 356df663dfc..eeab28154cc 100644
--- a/backends/xnnpack/serialization/schema.fbs
+++ b/backends/xnnpack/serialization/schema.fbs
@@ -29,6 +29,15 @@ enum XNNDatatype : short {
   xnn_datatype_qdint8 = 9,
   /// Quantized 4-bit signed integer with shared blockwise quantization parameters.
   xnn_datatype_qbint4 = 10,
+  /// Dynamically quantized 8-bit signed integers packed with their per-row
+  /// quantization parameters.
+  xnn_datatype_qpint8 = 11,
+  /// 32-bit signed integers.
+  xnn_datatype_int32 = 12,
+  /// IEEE754 single-precision packed floating-point.
+  xnn_datatype_pfp32 = 13,
+  /// BFloat16, i.e. the upper 16 bits of a float32.
+  xnn_datatype_bf16 = 14,
 }

 // type of quantization
diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py
index b8b4ea7f02f..dc50fb47da4 100644
--- a/backends/xnnpack/serialization/xnnpack_graph_schema.py
+++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -419,6 +419,10 @@ class XNNDatatype(IntEnum):
     xnn_datatype_qcint4 = 8
     xnn_datatype_qdint8 = 9
     xnn_datatype_qbint4 = 10
+    xnn_datatype_qpint8 = 11
+    xnn_datatype_int32 = 12
+    xnn_datatype_pfp32 = 13
+    xnn_datatype_bf16 = 14


 @dataclass

From 425ca7e6ed29ad96678c0d7398ec56fd978e4819 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Mon, 9 Jun 2025 15:10:34 -0700
Subject: [PATCH 2/3] Update

[ghstack-poisoned]
---
 backends/xnnpack/operators/__init__.py | 1 -
 backends/xnnpack/operators/op_sdpa.py | 111 ------------------
 backends/xnnpack/partition/config/__init__.py | 2 -
 .../partition/config/generic_node_configs.py | 30 -----
 backends/xnnpack/runtime/XNNCompiler.cpp | 37 ------
 5 files changed, 181 deletions(-)
 delete mode 100644 backends/xnnpack/operators/op_sdpa.py

diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index ec07502de54..a83f8706d94 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -39,7 +39,6 @@
     op_quant_dequant,
     op_relu,
     op_rsqrt,
-    op_sdpa,
     op_sigmoid,
     op_skip_ops,
     op_slice_copy,
diff --git a/backends/xnnpack/operators/op_sdpa.py b/backends/xnnpack/operators/op_sdpa.py
deleted file mode 100644
index e0ec7b37b3b..00000000000
--- a/backends/xnnpack/operators/op_sdpa.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import cast, Dict
-
-import torch
-from executorch.backends.transforms import get_shape
-from executorch.backends.xnnpack.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
-    XNNGraph,
-    XNNScaledDotProductAttention,
-    XNode,
-)
-from executorch.backends.xnnpack.utils.utils import get_input_node
-
-
-@register_node_visitor
-class SDPAVisitor(NodeVisitor):
-    target = "aten.scaled_dot_product_attention.default"
-
-    def __init__(self, *args) -> None:
-        super().__init__(*args)
-
-    @staticmethod
-    def get_fake_attr(name: str, value: torch.Tensor) -> torch.fx.Node:
-        g = torch.fx.Graph()
-        gm = torch.fx.GraphModule({}, g)
-        fake_node = torch.fx.Node(g, name, "get_attr", target=name, args=(), kwargs={})
-        g._owning_module = gm
-        setattr(g._owning_module, name, value)
-        fake_node.meta["val"] = value
-        return fake_node
-
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        xnn_graph: XNNGraph,
-        vals_to_ids: Dict[torch.fx.Node, int],
-        debug_handle: int,
-    ) -> None:
-        # inputs
-        for i in range(0, 4):
-            inp = get_input_node(node, i)
-            self.define_tensor(
-                inp,
-                xnn_graph,
-                vals_to_ids,
-            )
-
-        # Make sure mask is not bool
-        mask_node = get_input_node(node, 3)
-        mask_dtype = mask_node.meta["val"].dtype
-        assert mask_dtype in [
-            torch.float,
-            torch.float16,
-        ], "SDPA Mask must be a float (or half) tensor"
-
-        # Make sure mask is not >2D
-        assert len(get_shape(mask_node)) == 2, "SDPA Mask must be 2D"
-
-        # Hack to broadcast the scale
-        q_shape = get_shape(get_input_node(node, 0))
-        embedding_dim = q_shape[-1]
-        scale = 1 / (embedding_dim**0.5)
-        if "scale" in node.kwargs and node.kwargs["scale"]:
-            scale = cast(float, node.kwargs["scale"])
-
-        t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
-        scale_node = self.get_fake_attr("scale", t)
-        self.define_tensor(
-            scale_node,
-            xnn_graph,
-            vals_to_ids,
-        )
-
-        # outputs
-        outp = node
-        self.define_tensor(
-            outp,
-            xnn_graph,
-            vals_to_ids,
-        )
-
-        # ids
-        q_id = vals_to_ids[get_input_node(node, 0)]
-        k_id = vals_to_ids[get_input_node(node, 1)]
-        v_id = vals_to_ids[get_input_node(node, 2)]
-        mask_id = vals_to_ids[mask_node]
-        scale_id = vals_to_ids[scale_node]
-        output_id = vals_to_ids[outp]
-
-        # Create a new node
-        sdpa_node = XNode(
-            xnode_union=XNNScaledDotProductAttention(
-                query_id=q_id,
-                key_id=k_id,
-                value_id=v_id,
-                scale_id=scale_id,
-                mask_id=mask_id,
-                output_id=output_id,
-                flags=0,
-            ),
-            debug_handle=debug_handle,
-        )
-        xnn_graph.xnodes.append(sdpa_node)
diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py
index 553b10f60d1..b304317b257 100644
--- a/backends/xnnpack/partition/config/__init__.py
+++ b/backends/xnnpack/partition/config/__init__.py
@@ -43,7 +43,6 @@
     QuantizedPerTensorConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
-    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
@@ -99,7 +98,6 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
-    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 46922e47010..a8846b68d60 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -527,33 +527,3 @@ class BMMConfig(GenericNodePartitionerConfig):

     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
-
-
-class SDPAConfig(GenericNodePartitionerConfig):
-    target_name = "scaled_dot_product_attention.default"
-
-    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
-        """
-        Requires Mask to have Rank 2
-        """
-        if not self.check_common_constraints(node, ep):
-            return False
-
-        if len(node.all_input_nodes) < 4:
-            return False
-        mask_node = node.all_input_nodes[3]
-        mask_rank = mask_node.meta["val"].dim()
-        if mask_rank != 2:
-            why(
-                node,
-                reason=f"mask must have rank 2, got mask of rank {mask_rank}",
-            )
-            return False
-
-        return True
-
-    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        return torch.ops.aten.scaled_dot_product_attention.default
-
-    def supported_precision_types(self) -> List[ConfigPrecisionType]:
-        return [ConfigPrecisionType.FP32]
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 312cbc17b95..a364594fb1c 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1961,42 +1961,6 @@ Error defineStaticSliceNode(
   return Error::Ok;
 }

-/*
-Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
-using the remapped ids to map the serialized ids,
-to the new ids generated when defining the tensor value
-*/
-Error defineScaledDotProductAttentionNode(
-    xnn_subgraph_t subgraph_ptr,
-    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
-    const NodePtr node,
-    const fb_xnnpack::XNNGraph* graph) noexcept {
-  MAYBE_UNUSED(graph);
-
-  auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
-
-  xnn_status status = xnn_define_scaled_dot_product_attention(
-      subgraph_ptr,
-      xnn_attention_logits_cap_type_none, // cap_type
-      nullptr, // cap_value - not used
-      remapped_ids.at(graph_node->query_id()),
-      remapped_ids.at(graph_node->key_id()),
-      remapped_ids.at(graph_node->value_id()),
-      remapped_ids.at(graph_node->scale_id()),
-      remapped_ids.at(graph_node->mask_id()),
-      remapped_ids.at(graph_node->output_id()),
-      graph_node->flags());
-
-  ET_CHECK_OR_RETURN_ERROR(
-      status == xnn_status_success,
-      Internal,
-      "Failed to create SDPA node %i with code: %s",
-      node->debug_handle(),
-      xnn_status_to_string(status));
-
-  return Error::Ok;
-}
-
 /*
 Defines batch matrix multiply node into the subgraph,
 using the remapped ids to map the serialized ids,
 to the new ids generated when defining the tensor value
 */
 Error defineBatchMatrixMultiplyNode(
@@ -2097,7 +2061,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Concatenate4)
     _DEFINE(Concatenate5)
     _DEFINE(StaticSlice)
-    _DEFINE(ScaledDotProductAttention)
     _DEFINE(BatchMatrixMultiply)
     case fb_xnnpack::XNodeUnion::NONE:
     default: // Adding here as a catch all, just in case

From f3822ec6b23d901a1caeb5ec49f5f21295dc4858 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Mon, 23 Jun 2025 10:27:26 -0700
Subject: [PATCH 3/3] Update

[ghstack-poisoned]
---
 backends/xnnpack/test/ops/test_sdpa.py | 130 -------------------------
 1 file changed, 130 deletions(-)
 delete mode 100644 backends/xnnpack/test/ops/test_sdpa.py

diff --git a/backends/xnnpack/test/ops/test_sdpa.py b/backends/xnnpack/test/ops/test_sdpa.py
deleted file mode 100644
index 205b6d4ab36..00000000000
--- a/backends/xnnpack/test/ops/test_sdpa.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import unittest
-from typing import Optional
-
-import torch
-from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.backends.xnnpack.test.tester import Tester
-from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
-
-
-class TestSDPA(unittest.TestCase):
-    def setUp(self):
-        torch._dynamo.reset()
-
-    class SDPA(torch.nn.Module):
-        def __init__(self, scale: Optional[float] = None):
-            super().__init__()
-            self.dropout_p: float = 0.0
-            self.is_causal: bool = False
-            self.scale = scale
-
-        def forward(
-            self,
-            q: torch.Tensor,
-            k: torch.Tensor,
-            v: torch.Tensor,
-            mask: Optional[torch.Tensor] = None,
-        ):
-            return torch.nn.functional.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=mask,
-                dropout_p=self.dropout_p,
-                is_causal=self.is_causal,
-                scale=self.scale,
-            )
-
-        @staticmethod
-        def get_input_tensors(mask_rank: int, dtype: torch.dtype = torch.float32):
-            batch_size = 8
-            heads = 16
-            seq_len = 32
-            dim = 64
-
-            q = torch.randn(batch_size, heads, seq_len, dim).to(dtype)
-            k = torch.randn(batch_size, heads, seq_len, dim).to(dtype)
-            v = torch.randn(batch_size, heads, seq_len, dim).to(dtype)
-
-            mask = None
-            if mask_rank > 0:
-                assert mask_rank >= 2, "mask rank must be >= 2"
-                mask = torch.full((seq_len, seq_len), 0, dtype=dtype)
-                while mask.ndim < mask_rank:
-                    mask.unsqueeze_(0)
-
-            return (q, k, v, mask)
-
-    def _test(self, module, inputs, atol=1e-03, rtol=1e-03):
-        module = module.eval()
-        (
-            Tester(module, inputs)
-            .export()
-            .to_edge_transform_and_lower(
-                ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])])
-            )
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(
-                ["executorch_exir_dialects_edge__ops_aten_bmm_default"],
-            )
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
-        )
-
-    def test_fp16_sdpa_mask2d(self):
-        """
-        Tests that the SDPA operator is correctly lowered to XNNPACK
-        """
-        module = self.SDPA()
-        inputs = module.get_input_tensors(mask_rank=2, dtype=torch.float16)
-        self._test(module, inputs, atol=1e-02, rtol=1e-02)
-
-    def test_fp32_sdpa_mask2d(self):
-        """
-        Tests that the SDPA operator is correctly lowered to XNNPACK
-        """
-        module = self.SDPA()
-        inputs = module.get_input_tensors(mask_rank=2)
-        self._test(module, inputs)
-
-    def test_fp16_sdpa_userscale(self):
-        """
-        Tests that the scale parameter is passed correctly to the SDPA operator
-        """
-        module = self.SDPA(scale=0.1234)
-        inputs = module.get_input_tensors(mask_rank=2, dtype=torch.float16)
-        self._test(module, inputs, atol=1e-02, rtol=1e-02)
-
-    def test_fp32_sdpa_userscale(self):
-        """
-        Tests that the scale parameter is passed correctly to the SDPA operator
-        """
-        module = self.SDPA(scale=0.1234)
-        inputs = module.get_input_tensors(mask_rank=2)
-        self._test(module, inputs)
-
-    @unittest.expectedFailure
-    def test_fp32_sdpa_nomask(self):
-        module = self.SDPA()
-        inputs = module.get_input_tensors(mask_rank=0)
-        # AssertionError: SubgraphMatcher cannot be initialized with an pattern with dead code
-        # This is from attn_mask=None arg
-        self._test(module, inputs)
-
-    @unittest.expectedFailure
-    def test_fp32_sdpa_mask4d(self):
-        """
-        Tests that the scale parameter is passed correctly to the SDPA operator
- """ - module = self.SDPA(scale=0.1234) - # can't mask.squeeze_(0) yet with xnnpack - inputs = module.get_input_tensors(mask_rank=4) - self._test(module, inputs)