From 021f95d01ef6db9614c1d17a0fe6e20706ef4000 Mon Sep 17 00:00:00 2001 From: Aaron Ang <67321817+aaron-ang@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:51:57 -0700 Subject: [PATCH] Only support int8 dtype for quant operators --- backends/xnnpack/partition/config/xnnpack_config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index df6067a7d68..e0221c27d44 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -10,12 +10,13 @@ from typing import List, Optional import torch +from torch.export import ExportedProgram +from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, PartitionerConfig, ) from executorch.exir.backend.utils import WhyNoPartition -from torch.export import ExportedProgram logger = logging.getLogger(__name__) why = WhyNoPartition(logger=logger) @@ -220,9 +221,12 @@ def _check_node_has_valid_dtype(self, node): valid_dtypes = { torch.float32, torch.float16, - torch.int8, torch.qint8, } + # Only allow int8 for quantization operations + if is_dynamic_qdq(node): + valid_dtypes.add(torch.int8) + if ( node.op != "placeholder" and node.op != "call_function"