
Commit 55e5d40

[quant] Prototype for unified quantization API
Summary: To reduce the mental overhead for users, we would like to unify the quantization flow across different quantization methods. If we offer a separate API for each method, the APIs become hard to learn and use, since users need to be familiar with every one of them and know how to drive each flow, e.g.

- dynamic quant: API flow 1
- weight-only quant: API flow 2
- GPTQ: API flow 3
- static quant: API flow 4
- QAT: API flow 5

If each method has its own flow (some are one-line APIs, others take multiple steps with calibration or training), the UX suffers: users have to remember every flow, or keep going back to the tutorial, to use these APIs.

Instead, we'd like a unified quantization API and flow. Initial plan:

```python
class Quantizer():
    def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
        pass


class TwoStepQuantizer():
    # Note: each Quantizer will have their own implementation for prepare and convert
    def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
        # implementation 1
        # model = prepare_pt2e(model, self)
        # implementation 2, module swap, modifying weights with tensor subclass etc.
        # model = ...
        ...
        return model

    def convert(self, model: torch.nn.Module) -> torch.nn.Module:
        # implementation 1
        # model = convert_pt2e(model, self)
        ...
        # implementation 2
        # model = ...
        return model

    def save(self, model: torch.nn.Module, *args, **kwargs) -> None:
        pass

    def load(self, *args, **kwargs) -> torch.nn.Module:
        pass


class ExportQuantizer(TwoStepQuantizer):
    ...
    def annotate(self, model: fx.GraphModule) -> None:
        # [optional] used only in export based flow
        ...

    def prepare(self, model: fx.GraphModule) -> fx.GraphModule:
        ...


class XNNPACKQuantizer(ExportQuantizer):
    def annotate(...):
        ...

captured_model = capture(eager_model)
quantizer = XNNPACKQuantizer(captured_model)
model = quantizer.prepare(model)
model = quantizer.convert(model)

captured_model = capture(eager_model)
quantizer = Quantizer(captured_model, is_qat=True)
model = quantizer.prepare(model)
model = quantizer.convert(model)


class GPTQQuantizer(Quantizer):
    ...
    def quantize(...):
        ...
    def convert(...):
        ...

quantizer = GPTQQuantizer()
model = quantizer.quantize(eager_model)
torch.save(model.state_dict(), "gptq_weights.pt")

quantizer = GPTQQuantizer(load_time=True)
model = quantizer.quantize(eager_model)
model.load_state_dict(torch.load("gptq_weights.pt"))


class DynamicQuantizer(Quantizer):
    ...
    def quantize(...):
        ...

quantizer = DynamicQuantizer()
model = quantizer.quantize(eager_model)


class WeightOnlyQuantizer(Quantizer):
    ...
    def quantize(...):
        ...

quantizer = WeightOnlyQuantizer()
model = quantizer.quantize(eager_model)
```

Test Plan: python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e34a9d7
Pull Request resolved: #17
1 parent c9b397d commit 55e5d40
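
To make the one-step flow above concrete, here is a minimal sketch of what a `Quantizer` subclass could look like when implemented as a plain module swap. It assumes only the `Quantizer` base class that this commit adds to `torchao/quantization/quant_api.py`; the `Int8WeightOnlyLinear` wrapper, its quantization scheme, and the `SimpleWeightOnlyQuantizer` name are hypothetical and not part of the commit:

```python
import torch
from torchao.quantization.quant_api import Quantizer  # base class added in this commit


class Int8WeightOnlyLinear(torch.nn.Module):
    """Hypothetical wrapper: stores an int8 weight plus a per-output-channel scale."""

    def __init__(self, linear: torch.nn.Linear):
        super().__init__()
        w = linear.weight.detach()
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("int_weight", torch.round(w / scale).to(torch.int8))
        self.register_buffer("scale", scale)
        self.bias = linear.bias

    def forward(self, x):
        # dequantize on the fly; a real backend would use a fused int8 kernel instead
        w = self.int_weight.to(x.dtype) * self.scale
        return torch.nn.functional.linear(x, w, self.bias)


class SimpleWeightOnlyQuantizer(Quantizer):
    """One-step flow: a single quantize() call, no calibration or training needed."""

    def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
        for name, child in model.named_children():
            if isinstance(child, torch.nn.Linear):
                setattr(model, name, Int8WeightOnlyLinear(child))
            else:
                self.quantize(child)  # recurse into nested submodules
        return model


# usage follows the unified API:
# model = SimpleWeightOnlyQuantizer().quantize(eager_model)
```

A real implementation would more likely rely on tensor subclasses or fused kernels rather than eager dequantization, as hinted in the `TwoStepQuantizer` comments above.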

2 files changed: +152 -1 lines changed

test/quantization/test_quant_api.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
import unittest
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.quant_api import apply_dynamic_quant
from torchao.quantization.quant_api import (
    Quantizer,
    TwoStepQuantizer,
)

def dynamic_quant(model, example_inputs):
    m = capture_pre_autograd_graph(model, example_inputs)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
    m = prepare_pt2e(m, quantizer)
    m = convert_pt2e(m)
    return m

def _apply_dynamic_quant(model):
    """
    Applies dynamic symmetric per-token activation and per-channel weight
    quantization to all linear layers in the given model using
    module swaps.
    """
    _replace_with_custom_fn_if_matches_filter(
        model,
        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )
    return model


def capture_and_prepare(model, example_inputs):
    m = capture_pre_autograd_graph(model, example_inputs)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
    m = prepare_pt2e(m, quantizer)
    # TODO: we can run the weight observer in convert_pt2e so that users don't need to run this
    m(*example_inputs)
    return m

class XNNPackDynamicQuantizer(TwoStepQuantizer):

    def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
        _replace_with_custom_fn_if_matches_filter(
            model,
            lambda linear_mod: capture_and_prepare(linear_mod, (torch.randn(1, linear_mod.in_features),)),
            lambda mod, fqn: isinstance(mod, torch.nn.Linear),
        )
        return model

    def convert(self, model: torch.nn.Module) -> torch.nn.Module:
        _replace_with_custom_fn_if_matches_filter(
            model,
            lambda linear_mod: convert_pt2e(linear_mod),
            lambda mod, fqn: isinstance(mod, torch.fx.GraphModule),
        )
        return model

class TorchCompileDynamicQuantizer(Quantizer):
    def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
        apply_dynamic_quant(model)
        return model

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
        self.linear2 = torch.nn.Linear(5, 5).to(torch.float)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

class TestQuantFlow(unittest.TestCase):
    def test_dynamic_quant_gpu_singleline(self):
        m = M().eval()
        m = _apply_dynamic_quant(m)
        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
        quantized = m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        # m = torch.compile(m, mode="max-autotune")
        # print(example_inputs[0].dtype)
        # compiled = m(*example_inputs)
        # torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

    @unittest.skip("skipping for now due to torch.compile error")
    def test_dynamic_quant_gpu_unified_api_unified_impl(self):
        quantizer = XNNPackDynamicQuantizer()
        m = M().eval()
        m = quantizer.prepare(m)
        m = quantizer.convert(m)
        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
        quantized = m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        m = torch.compile(m, mode="max-autotune")
        # print(example_inputs[0].dtype)
        compiled = m(*example_inputs)
        torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

    def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
        quantizer = TorchCompileDynamicQuantizer()
        m = M().eval()
        m = quantizer.quantize(m)
        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
        quantized = m(*example_inputs)
        m = torch.compile(m, mode="max-autotune")
        compiled = m(*example_inputs)
        torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

    def test_gptq(self):
        # should be similar to TorchCompileDynamicQuantizer
        pass

if __name__ == "__main__":
    unittest.main()
```
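
For the `test_gptq` placeholder above, a rough sketch of what the test could eventually become, following the GPTQQuantizer save/load flow described in the commit summary. `GPTQQuantizer` and its `load_time` argument are only planned in that summary and are not implemented in this commit:

```python
# hypothetical sketch, not part of this diff
def test_gptq(self):
    # first run: quantize with GPTQ and persist the quantized state_dict
    quantizer = GPTQQuantizer()
    m = quantizer.quantize(M().eval())
    torch.save(m.state_dict(), "gptq_weights.pt")

    # later run: rebuild the quantized model structure, then load the saved weights
    quantizer = GPTQQuantizer(load_time=True)
    m2 = quantizer.quantize(M().eval())
    m2.load_state_dict(torch.load("gptq_weights.pt"))

    example_inputs = (torch.randn(1, 5),)
    torch.testing.assert_close(m(*example_inputs), m2(*example_inputs))
```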

torchao/quantization/quant_api.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -35,9 +35,27 @@
     "change_linear_weights_to_int8_dqtensors",
     "change_linear_weights_to_int8_woqtensors",
     "change_linear_weights_to_int4_woqtensors",
-    "swap_conv2d_1x1_to_linear"
+    "swap_conv2d_1x1_to_linear",
+    "Quantizer",
+    "TwoStepQuantizer",
 ]
 
+############################# Unified Quantization APIs ##############################
+# API 1, single quantize call to create a quantized model with quantized state_dict
+class Quantizer:
+    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+        pass
+
+
+# API 2, flow that needs calibration or training
+class TwoStepQuantizer:
+    def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
+        pass
+
+    def convert(self, model: torch.nn.Module) -> torch.nn.Module:
+        pass
+
+############################# Unified Quantization APIs ##############################
 
 def _replace_with_custom_fn_if_matches_filter(
     model, replacement_fn, filter_fn, cur_fqn=""
```
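
The `prepare` and `convert` stubs added here are intentionally empty; each backend supplies its own implementation. As a hedged sketch of how a two-step (calibration-based) flow is meant to be driven, where `MyStaticQuantizer` and `calibration_loader` are placeholders and not part of this commit:

```python
import torch

# placeholder subclass; any TwoStepQuantizer implementation is driven the same way
quantizer = MyStaticQuantizer()

model = quantizer.prepare(model)             # step 1: insert observers / capture the graph

with torch.no_grad():
    for example_inputs in calibration_loader:  # placeholder calibration data
        model(*example_inputs)                  # run calibration to collect statistics

model = quantizer.convert(model)             # step 2: fold observed qparams into quantized ops
```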
