
Commit 92fb0cc

Arm backend: Add function extract_io_quant_params (#12481)
Summary: Add a function that returns the I/O quantization parameters of a lowered graph and removes the corresponding Q/DQ nodes from that graph. If those nodes are still needed, copy the EdgeProgramManager before calling this function. Signed-off-by: Elena Zhelezina <elena.zhelezina@arm.com>
1 parent 8da2ea6 commit 92fb0cc
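
For orientation, a minimal usage sketch of the new helper follows. The wrapper name get_io_quant_params is hypothetical, and the sketch assumes edge_prog is an EdgeProgramManager produced by to_edge_transform_and_lower from a PT2E-quantized export (the unit test in this commit builds exactly such a program):

    import copy

    from executorch.exir import EdgeProgramManager
    from executorch.exir.passes.quantize_io_pass import extract_io_quant_params


    def get_io_quant_params(edge_prog: EdgeProgramManager) -> dict:
        # extract_io_quant_params strips the IO Q/DQ nodes it reads, so work on
        # a deep copy when the original lowered graph is still needed afterwards.
        scratch = copy.deepcopy(edge_prog)
        return extract_io_quant_params(scratch, input_idxs=(0,), output_idxs=(0,))

The returned dictionary maps input placeholder names and output node names to their scale/zero_point pairs, as described in the function's docstring below.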

File tree

2 files changed: +190 −1 lines


exir/passes/quantize_io_pass.py

Lines changed: 97 additions & 1 deletion
@@ -1,15 +1,21 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree
 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import numpy as np
 
 import torch
+import torch.fx as fx
 
 from executorch.exir import EdgeProgramManager, ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -316,3 +322,93 @@ def call(self, graph_module: torch.fx.GraphModule):
                 self.edge_manager_update_quant_config_method(i, self.dequant_args[i])
 
         return PassResult(graph_module, True)
+
+
+def extract_io_quant_params(
+    edge_prog: EdgeProgramManager,
+    *,
+    input_idxs: Sequence[int] = (0,),
+    output_idxs: Sequence[int] = (0,),
+) -> Dict[str, Dict[str, Dict[str, Any]]]:
+    """
+    Returns quantization parameters such as scale/zero_point:
+      {
+        "inputs": {
+          <placeholder_name>: {"scale": float, "zero_point": int}
+        },
+        "outputs": {
+          <node_name>: {"scale": float, "zero_point": int}
+        }
+      }
+
+    Note that this function will strip out the IO quantize/dequantize ops as
+    it records their parameters, so if you need to preserve the original graph
+    you need to make a copy with copy.deepcopy before.
+
+    Note that `to_edge_transform_and_lower` should be called before.
+    """
+    # Use IO passes
+    passes = []
+    for idx in input_idxs:
+        passes.append(QuantizeInputs(edge_prog, [idx]))
+    for idx in output_idxs:
+        passes.append(QuantizeOutputs(edge_prog, [idx]))
+
+    # Apply them
+    edge_prog = edge_prog.transform(passes)
+
+    cfg = getattr(edge_prog, "_config_methods", {}) or {}
+
+    # We need GraphModule to find node names
+    gm = edge_prog.exported_program().graph_module
+
+    input_names = _gather_io_names(gm, side="input")
+    output_names = _gather_io_names(gm, side="output")
+
+    # Build the result dict
+    result = {"inputs": {}, "outputs": {}}
+    for key, val in cfg.items():
+        if key.startswith("input"):
+            prefix, section, names = "input", "inputs", input_names
+        elif key.startswith("output"):
+            prefix, section, names = "output", "outputs", output_names
+        else:
+            continue
+
+        idx_str, param = key[len(prefix) :].split("_", 1)
+        idx = int(idx_str)
+        name = names[idx]
+        # We need to map 'zp' to 'zero_point'
+        out_param = "zero_point" if param in ("zp", "zero_point") else param
+        result[section].setdefault(name, {})[out_param] = val
+
+    return result
+
+
+def _gather_io_names(gm: fx.GraphModule, side: str):
+    """
+    For 'input', returns placeholder names in graph order.
+    For 'output', returns names of output nodes.
+    """
+    if side == "input":
+        return [n.name for n in gm.graph.nodes if n.op == "placeholder"]
+
+    if side == "output":
+
+        def _flatten(args):
+            out = []
+
+            def rec(x):
+                if isinstance(x, (tuple, list)):
+                    for y in x:
+                        rec(y)
+                elif isinstance(x, fx.Node):
+                    out.append(x)
+
+            rec(args)
+            return out
+
+        output_node = next(n for n in gm.graph.nodes if n.op == "output")
+        return [n.name for n in _flatten(output_node.args)]
+
+    raise ValueError(f"Unknown side: {side}")
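
Since the pass removes the IO quantize/dequantize ops, the extracted scale/zero_point values are typically used to quantize inputs (and dequantize outputs) outside the lowered program. A minimal sketch of that step, assuming int8 affine quantization and illustrative parameter values:

    import torch


    def quantize_input(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
        # Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax).
        # The int8 range below is an assumption; use whatever dtype the quantizer
        # actually chose for the model's inputs.
        q = torch.round(x / scale) + zero_point
        return torch.clamp(q, -128, 127).to(torch.int8)


    # 'params' would come from extract_io_quant_params(...)["inputs"][<placeholder_name>]
    params = {"scale": 0.05, "zero_point": 0}  # illustrative values only
    qx = quantize_input(torch.ones(1, 5), params["scale"], params["zero_point"])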
New unit test file — Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class SimpleAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class TestExtractIOQuantParamsPT2E(unittest.TestCase):
    def setUp(self):
        self.example_inputs = (
            torch.ones(1, 5),
            torch.full(
                (
                    1,
                    5,
                ),
                2.0,
            ),
        )
        self.mod = SimpleAdd().eval()

        # Setup XNNPACK quantizer for example
        self.quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config()
        self.quantizer.set_global(operator_config)

        exported = torch.export.export_for_training(
            self.mod,
            copy.deepcopy(self.example_inputs),
            strict=True,
        )
        prepared = prepare_pt2e(exported.module(), self.quantizer)

        # Call observers to calibrate
        _ = prepared(*self.example_inputs)

        converted = convert_pt2e(prepared)

        # Export again with quant parameters
        final_export = torch.export.export_for_training(
            converted,
            self.example_inputs,
            strict=True,
        )

        # Lower to EdgeProgramManager
        self.edge_prog = to_edge_transform_and_lower(final_export)

    def test_roundtrip_extracts_io_params(self):
        # Get dict with quant parameters
        q = extract_io_quant_params(
            self.edge_prog,
            input_idxs=(0, 1),
            output_idxs=(0,),
        )

        # Validate structure
        self.assertIn("inputs", q)
        self.assertIn("outputs", q)
        self.assertEqual(len(q["inputs"]), 2)
        self.assertEqual(len(q["outputs"]), 1)

        # Each entry must have a float 'scale' and int 'zero_point'
        for name, params in q["inputs"].items():
            self.assertIsInstance(name, str)
            self.assertIsInstance(params["scale"], float)
            self.assertIsInstance(params["zero_point"], int)

        out_name, out_params = next(iter(q["outputs"].items()))
        self.assertIsInstance(out_name, str)
        self.assertIsInstance(out_params["scale"], float)
        self.assertIsInstance(out_params["zero_point"], int)


if __name__ == "__main__":
    unittest.main()
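
For the SimpleAdd model in this test, the extracted dictionary has one entry per input placeholder and one for the output node. The sketch below shows only the expected shape; the node names and numeric values are illustrative, not the exact output of the test:

    # Illustrative shape of the result for the SimpleAdd test (names/values made up):
    example_result = {
        "inputs": {
            "x": {"scale": 0.02, "zero_point": 0},
            "y": {"scale": 0.02, "zero_point": 0},
        },
        "outputs": {
            "aten_add_tensor": {"scale": 0.03, "zero_point": 0},
        },
    }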
