
Commit 2c7dbbd

[quantizer] fix add_observer attribute error of torch_q (#220)
1 parent: 1e4ed63

File tree: 1 file changed, +7 −2 lines


tinynn/graph/quantization/quantizer.py

Lines changed: 7 additions & 2 deletions

@@ -3657,6 +3657,11 @@ def prepare_qat(
             if hasattr(n.module, "qconfig"):
                 delattr(n.module, "qconfig")
 
+        if hasattr(torch_q, 'add_observer_'):
+            add_observer_func = torch_q.add_observer_
+        else:
+            add_observer_func = sys.modules['torch.ao.quantization.quantize']._add_observer_
+
         if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"):
             if LooseVersion(torch.__version__) >= LooseVersion("1.13.0"):
                 prepare_custom_config_dict = torch.ao.quantization.get_default_custom_config_dict()
@@ -3667,13 +3672,13 @@ def prepare_qat(
                     "float_to_observed_custom_module_class", {}
                 )
 
-            torch_q.add_observer_(
+            add_observer_func(
                 graph.module,
                 qconfig_propagation_list=whitelist,
                 custom_module_class_mapping=custom_module_class_mapping,
             )
         else:
-            torch_q.add_observer_(
+            add_observer_func(
                 graph.module,
                 qconfig_propagation_list=whitelist,
             )
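For context on the fix: newer PyTorch releases no longer expose add_observer_ publicly (only the private _add_observer_ remains in torch.ao.quantization.quantize), so calling torch_q.add_observer_ raises an AttributeError. Below is a minimal standalone sketch of the fallback the commit applies, assuming torch_q refers to the torch.quantization module as aliased in quantizer.py; the resolve_add_observer helper name is illustrative and not part of TinyNN.

    import sys

    import torch.quantization as torch_q  # assumed to match TinyNN's `torch_q` alias


    def resolve_add_observer():
        """Pick whichever observer-insertion helper the installed PyTorch provides.

        Older releases expose the public `add_observer_`; newer ones only keep the
        private `_add_observer_` in `torch.ao.quantization.quantize`, which the
        commit looks up via `sys.modules` (that module is already loaded as a side
        effect of importing `torch.quantization`).
        """
        if hasattr(torch_q, 'add_observer_'):
            return torch_q.add_observer_
        return sys.modules['torch.ao.quantization.quantize']._add_observer_


    # Usage mirroring the patched call sites (arguments are illustrative):
    # add_observer_func = resolve_add_observer()
    # add_observer_func(model, qconfig_propagation_list=whitelist)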
