```diff
@@ -20,6 +20,8 @@
 from torch.fx import Node
 from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
```
```diff
@@ -213,6 +215,24 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
             _annotated=True,
         )
 
+    def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
+        act_node = node.args[0]
+        weight_node = node.args[2]
+
+        # TODO: currently only supports 16a16w
+        annotate_input_qspec_map(
+            node,
+            act_node,
+            quantization_config.input_activation,
+        )
+
+        annotate_input_qspec_map(
+            node,
+            weight_node,
+            quantization_config.input_activation,
+        )
+        annotate_output_qspec(node, quantization_config.output_activation)
+
     def annotate_single_in_single_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
```
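The two newly imported helpers do the actual metadata bookkeeping. For orientation, `aten.rms_norm.default` has the schema `rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None)`, which is why the activation is taken from `args[0]` and the affine weight from `args[2]`. The sketch below shows what `annotate_input_qspec_map` and `annotate_output_qspec` are assumed to do, inferred from the `QuantizationAnnotation` fields this file already uses; it is not the torchao implementation, and the `QUANT_ANNOTATION_KEY` constant name is an assumption.

```python
from torch.fx import Node

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation

# Assumed meta key; torchao may use a different constant internally.
QUANT_ANNOTATION_KEY = "quantization_annotation"


def _get_or_create_annotation(node: Node) -> QuantizationAnnotation:
    # Reuse any existing annotation so successive calls accumulate
    # qspecs instead of overwriting each other.
    if QUANT_ANNOTATION_KEY not in node.meta:
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation()
    return node.meta[QUANT_ANNOTATION_KEY]


def annotate_input_qspec_map(node: Node, input_node: Node, qspec) -> None:
    # Record how one input edge (input_node -> node) should be quantized.
    _get_or_create_annotation(node).input_qspec_map[input_node] = qspec


def annotate_output_qspec(node: Node, qspec) -> None:
    # Record how the node's output should be quantized.
    _get_or_create_annotation(node).output_qspec = qspec
```

Note that the weight is annotated with the activation qspec rather than a dedicated weight spec, which is what the 16a16w-only TODO above refers to.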
|
```diff
@@ -287,6 +307,9 @@ def annotate_matmul_input1(node: Node)
             elif node.target == torch.ops.aten.flatten.using_ints:
                 annotate_single_in_share_out(node, quantization_config_8a8w)
                 node = node.args[0]
+            elif node.target == torch.ops.aten.rms_norm.default:
+                annotate_rms_norm(node, quantization_config_8a8w)
+                node = node.args[0]
             elif node.target == torch.ops.aten.cat.default:
                 annotate_cat(node, quantization_config_8a8w)
                 # For v, we tag 8a until conv op.
```
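The new arm slots into a producer walk: starting from matmul's second input, the loop follows `node.args[0]` upstream, tagging each supported op with the 8-bit config, so an `rms_norm` in the chain no longer stops the walk. A minimal sketch of that traversal pattern, assuming the enclosing loop has the usual shape (the diff only shows the `elif` arms):

```python
# Assumed shape of the enclosing walk; only the elif arms appear in the diff.
def annotate_matmul_input1(node: Node):
    while isinstance(node, Node) and node.op == "call_function":
        if node.target == torch.ops.aten.rms_norm.default:
            annotate_rms_norm(node, quantization_config_8a8w)
            node = node.args[0]  # continue from the normalized activation
        elif node.target == torch.ops.aten.flatten.using_ints:
            annotate_single_in_share_out(node, quantization_config_8a8w)
            node = node.args[0]
        else:
            break  # unsupported producer: stop the walk
```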