
Commit 5e5db71

Make torchao pt2e prepare/convert functions compatible with quantizers in torch.ao (#2221)
* lint
* up
* up
* up
* lint
1 parent 5549da8 commit 5e5db71

File tree

1 file changed: +64 −0 lines changed

torchao/quantization/pt2e/quantize_pt2e.py

Lines changed: 64 additions & 0 deletions
@@ -95,6 +95,15 @@ def calibrate(model, data_loader):
         # run calibration
         # calibrate(m, sample_inference_data)
     """
+    # We will temporarily make prepare_pt2e backward compatible with quantizers that use configs,
+    # observers, and fake quantizers from torch.ao instead of torchao
+    if isinstance(quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer):
+        from torch.ao.quantization.quantize_pt2e import (
+            prepare_pt2e as torch_prepare_pt2e,
+        )
+
+        return torch_prepare_pt2e(model, quantizer)
+
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
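This check means any quantizer built on torch.ao's base class is routed to the old flow. A minimal sketch of what passes the check, using torch.ao's XNNPACKQuantizer as an illustrative example (not part of this commit):

    import torch
    from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer

    # XNNPACKQuantizer subclasses torch.ao's Quantizer base class, so calling
    # torchao's prepare_pt2e with it takes the compatibility branch above and
    # is forwarded to torch.ao.quantization.quantize_pt2e.prepare_pt2e.
    quantizer = XNNPACKQuantizer()
    assert isinstance(quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer)

The same isinstance dispatch is added to prepare_qat_pt2e in the next hunk.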
@@ -172,6 +181,15 @@ def train_loop(model, train_data):
         train_loop(prepared_model, train_data)

     """
+    # We will temporarily make prepare_qat_pt2e backward compatible with quantizers that use configs,
+    # observers, and fake quantizers from torch.ao instead of torchao
+    if isinstance(quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer):
+        from torch.ao.quantization.quantize_pt2e import (
+            prepare_qat_pt2e as torch_prepare_qat_pt2e,
+        )
+
+        return torch_prepare_qat_pt2e(model, quantizer)
+
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
@@ -217,6 +235,43 @@ def _quant_node_constraint(n: Node) -> bool:
     return n.op == "call_function" and n.target in _QUANT_OPS


+def _is_torchao_prepared_do_not_use_outside_this_file(model):
+    from torchao.quantization.pt2e.fake_quantize import (
+        FakeQuantize as torchao_FakeQuantize,
+    )
+    from torchao.quantization.pt2e.observer import (
+        AffineQuantizedObserverBase as torchao_AffineQuantizedObserverBase,
+    )
+    from torchao.quantization.pt2e.observer import ObserverBase as torchao_ObserverBase
+
+    is_torch_ao_prepared = False
+    is_torchao_prepared = False
+    for _, m in model.named_modules():
+        if (
+            isinstance(m, torch.ao.quantization.fake_quantize.FakeQuantize)
+            or isinstance(m, torch.ao.quantization.observer.ObserverBase)
+            or isinstance(m, torch.ao.quantization.observer.AffineQuantizedObserverBase)
+        ):
+            is_torch_ao_prepared = True
+        if (
+            isinstance(m, torchao_FakeQuantize)
+            or isinstance(m, torchao_ObserverBase)
+            or isinstance(m, torchao_AffineQuantizedObserverBase)
+        ):
+            is_torchao_prepared = True
+
+    if is_torch_ao_prepared:
+        assert not is_torchao_prepared, (
+            "Cannot be prepared using both torch.ao and torchao"
+        )
+    if is_torchao_prepared:
+        assert not is_torch_ao_prepared, (
+            "Cannot be prepared using both torch.ao and torchao"
+        )
+
+    return is_torchao_prepared
+
+
 def convert_pt2e(
     model: GraphModule,
     use_reference_representation: bool = False,
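The helper relies on the fact that preparation attaches observer/fake-quantize submodules to the graph module, so their types identify which library prepared it. A hypothetical standalone illustration of that named_modules() scan (the Toy module and its observer are made up for this sketch, not from the commit):

    import torch
    from torch.ao.quantization.observer import MinMaxObserver

    # After a torch.ao prepare, torch.ao observer modules hang off the graph
    # module; scanning submodule types is enough to tell the two flows apart.
    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.obs = MinMaxObserver()

        def forward(self, x):
            return self.obs(x)

    model = Toy()
    prepared_with_torch_ao = any(
        isinstance(m, torch.ao.quantization.observer.ObserverBase)
        for _, m in model.named_modules()
    )
    assert prepared_with_torch_ao  # MinMaxObserver subclasses torch.ao's ObserverBase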
@@ -243,6 +298,15 @@ def convert_pt2e(
         quantized_model = convert_pt2e(prepared_model)

     """
+    # We will temporarily make convert_pt2e backward compatible with quantizers that use configs,
+    # observers, and fake quantizers from torch.ao instead of torchao
+    if not _is_torchao_prepared_do_not_use_outside_this_file(model):
+        from torch.ao.quantization.quantize_pt2e import (
+            convert_pt2e as torch_convert_pt2e,
+        )
+
+        return torch_convert_pt2e(model, use_reference_representation, fold_quantize)
+
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
     if not isinstance(use_reference_representation, bool):
         raise ValueError(
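Taken together, the three dispatch points let an unmigrated torch.ao quantizer run through torchao's API end to end. A hedged sketch of the full flow, assuming the usual pt2e export step and an illustrative model and quantizer (names here are examples, not from the commit):

    import torch
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    # Export the eager model to a GraphModule, as pt2e flows expect.
    example_inputs = (torch.randn(1, 8),)
    m = torch.export.export(M(), example_inputs).module()

    # The quantizer is a torch.ao Quantizer, so torchao's prepare_pt2e
    # delegates to torch.ao.quantization.quantize_pt2e.prepare_pt2e.
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(m, quantizer)

    # Calibrate so the (torch.ao) observers record activation ranges.
    prepared(*example_inputs)

    # The detection helper finds torch.ao observers in the prepared graph,
    # so convert_pt2e delegates to torch.ao's convert_pt2e as well.
    quantized = convert_pt2e(prepared)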
