Commit 0b62f3f

Replace the "quantization_annotation" string with a constant variable (#2525)
Summary: Create a constant `Q_ANNOTATION_KEY` to avoid manually typing the `"quantization_annotation"` string, which is error-prone.

Differential Revision: D78133734
1 parent c663e30 commit 0b62f3f
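
For illustration, a minimal sketch of the pattern this commit establishes (the toy module and traced graph below are hypothetical; only `Q_ANNOTATION_KEY`, `QuantizationAnnotation`, and the import path come from this change):

import torch

from torchao.quantization.pt2e.quantizer.quantizer import (
    Q_ANNOTATION_KEY,
    QuantizationAnnotation,
)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = torch.fx.symbolic_trace(ToyModel())

for node in gm.graph.nodes:
    if node.op == "call_module":
        # Before: call sites spelled out the raw string, so a typo such as
        # "quantization_anotation" would silently skip the node:
        #     node.meta["quantization_annotation"] = QuantizationAnnotation()
        # After: the shared constant is used, and a misspelled name fails
        # loudly at import/name-lookup time instead.
        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()

annotated = [n for n in gm.graph.nodes if Q_ANNOTATION_KEY in n.meta]
assert len(annotated) == 1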

File tree

7 files changed: +35 -46 lines changed


torchao/quantization/pt2e/prepare.py

Lines changed: 7 additions & 13 deletions
@@ -13,10 +13,7 @@
 from torch._subclasses import FakeTensor
 from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
-from torch.ao.quantization.fx.prepare import (
-    _insert_obs_or_fq,
-    _save_state,
-)
+from torch.ao.quantization.fx.prepare import _insert_obs_or_fq, _save_state
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.fx import Graph, GraphModule, Node
 from torch.fx.node import Argument
@@ -26,9 +23,7 @@
     DerivedObserverOrFakeQuantize,
     ObserverOrFakeQuantize,
 )
-from torchao.quantization.pt2e.fake_quantize import (
-    FixedQParamsFakeQuantize,
-)
+from torchao.quantization.pt2e.fake_quantize import FixedQParamsFakeQuantize
 from torchao.quantization.pt2e.observer import (
     FixedQParamsObserver,
     PartialWrapper,
@@ -42,6 +37,7 @@
     QuantizationSpecBase,
     SharedQuantizationSpec,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 # TODO: make pt2e folder private?
@@ -208,8 +204,8 @@ def _get_edge_or_node_to_qspec(
     """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes"""
     edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {}
     for n in model.graph.nodes:
-        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
-            qa = n.meta["quantization_annotation"]
+        if hasattr(n, "meta") and Q_ANNOTATION_KEY in n.meta:
+            qa = n.meta[Q_ANNOTATION_KEY]
             for input_to_n, qspec in qa.input_qspec_map.items():
                 input_edge = (input_to_n, n)
                 edge_or_node_to_qspec[input_edge] = qspec
@@ -324,7 +320,7 @@ def _get_edge_or_node_to_group_id(

         assert isinstance(input_edge, tuple)
         arg, n = input_edge
-        if n.meta["quantization_annotation"].allow_implicit_sharing:
+        if n.meta[Q_ANNOTATION_KEY].allow_implicit_sharing:
             # NOTE: the order is important here, we first share with other users and then share with previous
             # output because the reverse order could cause circular dependency
             # e.g node1 -> node2
@@ -571,9 +567,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     is_qat: bool,
 ):
     this_node_quantization_annotation = (
-        node.meta["quantization_annotation"]
-        if "quantization_annotation" in node.meta
-        else None
+        node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
     )
     if this_node_quantization_annotation is None:
         return

torchao/quantization/pt2e/quantizer/composable_quantizer.py

Lines changed: 5 additions & 4 deletions
@@ -8,6 +8,8 @@

 from typing import TYPE_CHECKING

+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+
 from .quantizer import QuantizationAnnotation, Quantizer

 if TYPE_CHECKING:
@@ -48,18 +50,17 @@ def _record_and_validate_annotations(
         self, gm: torch.fx.GraphModule, quantizer: Quantizer
     ) -> None:
         for n in gm.graph.nodes:
-            if "quantization_annotation" in n.meta:
+            if Q_ANNOTATION_KEY in n.meta:
                 # check if the annotation has been changed by
                 # comparing QuantizationAnnotation object id
                 if n in self._graph_annotations and (
-                    id(self._graph_annotations[n])
-                    != id(n.meta["quantization_annotation"])
+                    id(self._graph_annotations[n]) != id(n.meta[Q_ANNOTATION_KEY])
                 ):
                     raise RuntimeError(
                         f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                     )
                 else:
-                    self._graph_annotations[n] = n.meta["quantization_annotation"]
+                    self._graph_annotations[n] = n.meta[Q_ANNOTATION_KEY]
             else:
                 if n in self._graph_annotations:
                     raise RuntimeError(

torchao/quantization/pt2e/quantizer/duplicate_dq_pass.py

Lines changed: 4 additions & 7 deletions
@@ -12,13 +12,10 @@
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase, PassResult

-from torchao.quantization.pt2e.utils import (
-    _filter_sym_size_users,
-)
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+from torchao.quantization.pt2e.utils import _filter_sym_size_users

-from .utils import (
-    is_valid_annotation,
-)
+from .utils import is_valid_annotation

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -41,7 +38,7 @@
 def _maybe_duplicate_dq(
     gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
 ):
-    annotation = user.meta.get("quantization_annotation", None)
+    annotation = user.meta.get(Q_ANNOTATION_KEY, None)
     if not is_valid_annotation(annotation):
         return
     with gm.graph.inserting_after(dq_node):

torchao/quantization/pt2e/quantizer/embedding_quantizer.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
     QuantizationSpec,
     Quantizer,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

 __all__ = [
     "get_embedding_operators_config",
@@ -87,7 +88,7 @@ def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
                     raise ValueError(
                         "Embedding config must have a valid weight quantization spec."
                     )
-                node.meta["quantization_annotation"] = QuantizationAnnotation(
+                node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                     input_qspec_map={
                         node.args[0]: embedding_config.config.weight,
                     }

torchao/quantization/pt2e/quantizer/port_metadata_pass.py

Lines changed: 8 additions & 13 deletions
@@ -12,18 +12,13 @@
 from torch._export.error import InternalError
 from torch.fx.passes.infra.pass_base import PassBase, PassResult

-from torchao.quantization.pt2e.utils import (
-    _filter_sym_size_users,
-)
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+from torchao.quantization.pt2e.utils import _filter_sym_size_users
 from torchao.quantization.quant_primitives import quant_lib  # noqa: F401
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-from .quantizer import (
-    QuantizationSpecBase,
-)
-from .utils import (
-    is_valid_annotation,
-)
+from .quantizer import QuantizationSpecBase
+from .utils import is_valid_annotation

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.ERROR)
@@ -68,7 +63,7 @@ def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:


 def _has_quant_annotation(node: torch.fx.Node) -> bool:
-    return "quantization_annotation" in node.meta
+    return Q_ANNOTATION_KEY in node.meta


 def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
@@ -281,10 +276,10 @@ class PortNodeMetaForQDQ(PassBase):

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
-            annotation = node.meta.get("quantization_annotation", None)
+            annotation = node.meta.get(Q_ANNOTATION_KEY, None)
             if is_valid_annotation(annotation):
-                input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
-                output_qspec = node.meta["quantization_annotation"].output_qspec
+                input_qspec_map = node.meta[Q_ANNOTATION_KEY].input_qspec_map
+                output_qspec = node.meta[Q_ANNOTATION_KEY].output_qspec
                 for input_node, qspec in input_qspec_map.items():
                     _port_metadata_for_input_quant_nodes(input_node, node, qspec)
                 _port_metadata_for_output_quant_nodes(node, output_qspec)

torchao/quantization/pt2e/quantizer/quantizer.py

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@
 ]


+Q_ANNOTATION_KEY = "quantization_annotation"
+
+
 class QuantizationSpecBase(ABC):  # noqa: B024
     """Base class for different types of quantization specs that allows users to
     specify how to quantize a Tensor (input/output of a Node) in the model

torchao/quantization/pt2e/quantizer/utils.py

Lines changed: 6 additions & 8 deletions
@@ -13,6 +13,8 @@
 import torch
 from torch.fx import Node

+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+
 from .quantizer import QuantizationAnnotation, QuantizationSpec


@@ -103,21 +105,17 @@ def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):


 def annotate_input_qspec_map(node: Node, input_node: Node, qspec):
-    quantization_annotation = node.meta.get(
-        "quantization_annotation", QuantizationAnnotation()
-    )
+    quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, QuantizationAnnotation())
     if quantization_annotation.input_qspec_map is None:
         quantization_annotation.input_qspec_map = {}
     quantization_annotation.input_qspec_map[input_node] = qspec
-    node.meta["quantization_annotation"] = quantization_annotation
+    node.meta[Q_ANNOTATION_KEY] = quantization_annotation


 def annotate_output_qspec(node: Node, qspec):
-    quantization_annotation = node.meta.get(
-        "quantization_annotation", QuantizationAnnotation()
-    )
+    quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, QuantizationAnnotation())
     quantization_annotation.output_qspec = qspec
-    node.meta["quantization_annotation"] = quantization_annotation
+    node.meta[Q_ANNOTATION_KEY] = quantization_annotation


 def get_module_name_filter(module_name: str):
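
For reference, a minimal sketch of how the two helpers above are meant to be called after this change (the `annotate_node` wrapper and its `qspec` argument are hypothetical; the helper names and their behavior come from the hunk above, and the module path is assumed from the file header):

from torch.fx import Node

from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
from torchao.quantization.pt2e.quantizer.utils import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)


def annotate_node(node: Node, input_node: Node, qspec) -> None:
    # Both helpers lazily create a QuantizationAnnotation and store it under
    # Q_ANNOTATION_KEY, so callers never spell out the raw string themselves.
    annotate_input_qspec_map(node, input_node, qspec)
    annotate_output_qspec(node, qspec)
    assert Q_ANNOTATION_KEY in node.meta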
