
Commit ef3cefe: annotate the rms_norm

Differential Revision: D77768317
Pull Request resolved: #12238
Parent: f6bb143

1 file changed: backends/qualcomm/quantizer/custom_annotation.py (+23, −0)
@@ -20,6 +20,8 @@
 from torch.fx import Node
 from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
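
The two newly imported helpers attach QuantizationSpec metadata to an FX node, one entry per input edge plus one for the output. A rough functional sketch of their effect, assuming they follow the standard pt2e "quantization_annotation" meta convention (an approximation, not the torchao implementation, which also merges with any pre-existing annotation):

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
)

# Sketch only: approximate behavior of the imported helpers.
def _annotate_input_qspec_map(node: Node, input_node: Node, qspec: QuantizationSpec) -> None:
    # Record how the edge (input_node -> node) should be quantized.
    ann = node.meta.setdefault("quantization_annotation", QuantizationAnnotation())
    ann.input_qspec_map[input_node] = qspec

def _annotate_output_qspec(node: Node, qspec: QuantizationSpec) -> None:
    # Record how node's output should be quantized.
    ann = node.meta.setdefault("quantization_annotation", QuantizationAnnotation())
    ann.output_qspec = qspec
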
@@ -213,6 +215,24 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
         _annotated=True,
     )
 
+def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
+    act_node = node.args[0]
+    weight_node = node.args[2]
+
+    # TODO: currently only supports 16a16w
+    annotate_input_qspec_map(
+        node,
+        act_node,
+        quantization_config.input_activation,
+    )
+
+    annotate_input_qspec_map(
+        node,
+        weight_node,
+        quantization_config.input_activation,
+    )
+    annotate_output_qspec(node, quantization_config.output_activation)
+
 def annotate_single_in_single_out(
     node: Node, quantization_config: QuantizationConfig
 ) -> None:
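
The argument indexing in annotate_rms_norm follows the op schema: aten.rms_norm takes (input, normalized_shape, weight, eps), so args[0] is the input activation and args[2] is the optional weight. Note that the weight is annotated with the input_activation spec rather than a separate weight spec, so activations and weight share one spec (per the 16a16w TODO). A quick sanity check, assuming a PyTorch build recent enough to ship aten.rms_norm:

import torch

# Print the overload's schema; args[0] is the input, args[2] the weight.
print(torch.ops.aten.rms_norm.default._schema)
# Expected output, roughly:
# aten::rms_norm(Tensor input, int[] normalized_shape,
#                Tensor? weight=None, float? eps=None) -> Tensor
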
@@ -287,6 +307,9 @@ def annotate_matmul_input1(node: Node):
         elif node.target == torch.ops.aten.flatten.using_ints:
             annotate_single_in_share_out(node, quantization_config_8a8w)
             node = node.args[0]
+        elif node.target == torch.ops.aten.rms_norm.default:
+            annotate_rms_norm(node, quantization_config_8a8w)
+            node = node.args[0]
         elif node.target == torch.ops.aten.cat.default:
             annotate_cat(node, quantization_config_8a8w)
             # For v, we tag 8a until conv op.
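
To exercise the new branch, the enclosing custom annotation (annotate_matmul_16a8w in this file, of which annotate_matmul_input1 is a local helper) must be registered with the quantizer. A hypothetical wiring sketch, assuming QnnQuantizer exposes add_custom_quant_annotations as the Qualcomm example scripts use it:

# Hypothetical usage sketch; add_custom_quant_annotations and the import
# paths are assumed from the surrounding backend code, not from this diff.
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()
# While walking back from a matmul's second input, aten.rms_norm.default
# nodes on that path now get 8a8w annotations via annotate_rms_norm.
quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))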
