@@ -95,6 +95,15 @@ def calibrate(model, data_loader):
         # run calibration
         # calibrate(m, sample_inference_data)
     """
+    # We will temporarily make prepare_pt2e backward compatible with quantizers that use configs, observers,
+    # and fake quantizers from torch.ao instead of torchao
+    if isinstance(quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer):
+        from torch.ao.quantization.quantize_pt2e import (
+            prepare_pt2e as torch_prepare_pt2e,
+        )
+
+        return torch_prepare_pt2e(model, quantizer)
+
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
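The new branch keys purely off the quantizer's base class: anything deriving from torch.ao's Quantizer takes the fallback, while a quantizer built on torchao's own base class continues into the torchao implementation below. A minimal sketch of that criterion; MyLegacyQuantizer is a hypothetical class used only for illustration:

from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAoQuantizer


class MyLegacyQuantizer(TorchAoQuantizer):
    """Hypothetical quantizer built on torch.ao's base class, configs, and observers."""

    def annotate(self, model):
        return model  # real annotation logic elided in this sketch

    def validate(self, model) -> None:
        pass


# Passing MyLegacyQuantizer() to torchao's prepare_pt2e (or prepare_qat_pt2e below)
# would satisfy the isinstance check and be handed to torch.ao's implementation.
assert isinstance(MyLegacyQuantizer(), TorchAoQuantizer)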
@@ -172,6 +181,15 @@ def train_loop(model, train_data):
         train_loop(prepared_model, train_data)

     """
+    # We will temporarily make prepare_qat_pt2e backward compatible with quantizers that use configs, observers,
+    # and fake quantizers from torch.ao instead of torchao
+    if isinstance(quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer):
+        from torch.ao.quantization.quantize_pt2e import (
+            prepare_qat_pt2e as torch_prepare_qat_pt2e,
+        )
+
+        return torch_prepare_qat_pt2e(model, quantizer)
+
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
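A usage sketch of the QAT fallback, assuming the wrapper is importable from torchao.quantization.pt2e.quantize_pt2e and that torch.ao's XNNPACKQuantizer is available; the export step and training loop are shown schematically:

import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# assumed import path for the wrapper added in the hunk above
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
m = torch.export.export(M(), example_inputs).module()

# XNNPACKQuantizer is a torch.ao Quantizer, so this call returns
# torch.ao's prepare_qat_pt2e(model, quantizer) unchanged
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(m, quantizer)
# train_loop(prepared, train_data)  # QAT fine-tuning, as in the docstring example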
@@ -217,6 +235,43 @@ def _quant_node_constraint(n: Node) -> bool:
     return n.op == "call_function" and n.target in _QUANT_OPS


+def _is_torchao_prepared_do_not_use_outside_this_file(model):
+    from torchao.quantization.pt2e.fake_quantize import (
+        FakeQuantize as torchao_FakeQuantize,
+    )
+    from torchao.quantization.pt2e.observer import (
+        AffineQuantizedObserverBase as torchao_AffineQuantizedObserverBase,
+    )
+    from torchao.quantization.pt2e.observer import ObserverBase as torchao_ObserverBase
+
+    is_torch_ao_prepared = False
+    is_torchao_prepared = False
+    for _, m in model.named_modules():
+        if (
+            isinstance(m, torch.ao.quantization.fake_quantize.FakeQuantize)
+            or isinstance(m, torch.ao.quantization.observer.ObserverBase)
+            or isinstance(m, torch.ao.quantization.observer.AffineQuantizedObserverBase)
+        ):
+            is_torch_ao_prepared = True
+        if (
+            isinstance(m, torchao_FakeQuantize)
+            or isinstance(m, torchao_ObserverBase)
+            or isinstance(m, torchao_AffineQuantizedObserverBase)
+        ):
+            is_torchao_prepared = True
+
+    if is_torch_ao_prepared:
+        assert not is_torchao_prepared, (
+            "Cannot be prepared using both torch.ao and torchao"
+        )
+    if is_torchao_prepared:
+        assert not is_torch_ao_prepared, (
+            "Cannot be prepared using both torch.ao and torchao"
+        )
+
+    return is_torchao_prepared
+
+
 def convert_pt2e(
     model: GraphModule,
     use_reference_representation: bool = False,
@@ -243,6 +298,15 @@ def convert_pt2e(
         quantized_model = convert_pt2e(prepared_model)

     """
+    # We will temporarily make convert_pt2e backward compatible with quantizers that use configs, observers,
+    # and fake quantizers from torch.ao instead of torchao
+    if not _is_torchao_prepared_do_not_use_outside_this_file(model):
+        from torch.ao.quantization.quantize_pt2e import (
+            convert_pt2e as torch_convert_pt2e,
+        )
+
+        return torch_convert_pt2e(model, use_reference_representation, fold_quantize)
+
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
     if not isinstance(use_reference_representation, bool):
         raise ValueError(
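A sketch of the backward-compatible PTQ flow end to end: prepare_pt2e delegates to torch.ao, the prepared model then carries only torch.ao observers, so _is_torchao_prepared_do_not_use_outside_this_file returns False and convert_pt2e delegates as well. The torchao.quantization.pt2e.quantize_pt2e import path and the availability of torch.ao's XNNPACKQuantizer are assumptions:

import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# assumed import path for the wrappers added in this diff
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
m = torch.export.export(M(), example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(m, quantizer)  # fallback -> torch.ao prepare_pt2e
prepared(*example_inputs)              # calibration pass

# prepared holds torch.ao observers only, so the helper reports it is not
# torchao-prepared and convert_pt2e returns torch.ao's convert_pt2e(...) result
quantized = convert_pt2e(prepared)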