Commit e213cc5

fix bug in orchest on PT1.12 (#1109)
1 parent 2482190 commit e213cc5

3 files changed (+46, -29 lines)

neural_compressor/adaptor/pytorch.py

Lines changed: 40 additions & 28 deletions
@@ -2717,13 +2717,25 @@ def _pre_hook_for_qat(self):
                                         qscheme=torch.per_tensor_affine,
                                         reduce_range=REDUCE_RANGE,
                                         observer=torch.quantization.MovingAverageMinMaxObserver),
-                        weight=torch.quantization.default_weight_fake_quant)
+                        weight=torch.quantization.default_weight_fake_quant) \
+                    if self.version < PyTorchVersionMode.PT110.value else \
+                    torch.quantization.QConfig(
+                        activation=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
+                                         dtype=torch.quint8,
+                                         qscheme=torch.per_tensor_affine,
+                                         reduce_range=REDUCE_RANGE),
+                        weight=torch.quantization.default_fused_per_channel_wt_fake_quant)
         quantizable_ops = []
         tmp_model = self.fuse_fx_model(self.model, is_qat=True)
         self._get_quantizable_ops_recursively(tmp_model, '', quantizable_ops)
         quantized_ops = {op[0]:q_cfgs for op in quantizable_ops}
         if self.version < PyTorchVersionMode.PT111.value:
             quantized_ops["default_qconfig"] = None
+        else:
+            from torch.ao.quantization import default_embedding_qat_qconfig
+            for op in quantizable_ops:
+                if op[1] in ['Embedding', 'EmbeddingBag']:
+                    quantized_ops[op[0]] = default_embedding_qat_qconfig
         from torch.quantization.quantize_fx import prepare_qat_fx
         fx_op_cfgs = _cfgs_to_fx_cfgs(quantized_ops, 'quant_aware_training')
         self.model._model.train()
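
The gate above picks between two QAT qconfigs depending on the PyTorch version. As a standalone illustration outside the diff context, the selection can be read as the sketch below; make_qat_qconfig and the pt110_or_newer flag are illustrative names that are not in the patch, REDUCE_RANGE stands in for the adaptor's constant, and the pre-1.10 branch assumes the unchanged surrounding lines build a FakeQuantize activation, as the visible context suggests.

# Illustrative sketch only; not part of the patch.
import torch

REDUCE_RANGE = False  # assumption: the real value depends on the target hardware

def make_qat_qconfig(pt110_or_newer):
    if not pt110_or_newer:
        # Pre-1.10: eager-style FakeQuantize wrapping a moving-average observer.
        return torch.quantization.QConfig(
            activation=torch.quantization.FakeQuantize.with_args(
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
                reduce_range=REDUCE_RANGE,
                observer=torch.quantization.MovingAverageMinMaxObserver),
            weight=torch.quantization.default_weight_fake_quant)
    # 1.10+: fused observer/fake-quant for activations and the fused
    # per-channel weight fake-quant, matching the patched _pre_hook_for_qat.
    return torch.quantization.QConfig(
        activation=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=REDUCE_RANGE),
        weight=torch.quantization.default_fused_per_channel_wt_fake_quant)

The fused classes combine the observer and fake-quant steps into one module, which is what PyTorch 1.10 introduced them for.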
@@ -3136,35 +3148,35 @@ def fuse_fx_model(self, model, is_qat):
         """
         try:
             tmp_model = copy.deepcopy(model._model)
+            tmp_model.train() if is_qat else tmp_model.eval()
+            from torch.fx import GraphModule
+            from torch.quantization.quantize_fx import _fuse_fx, QuantizationTracer
+            if model.kwargs is not None:
+                prepare_custom_config_dict = model.kwargs.get(
+                    'prepare_custom_config_dict', {})
+            else:
+                prepare_custom_config_dict = {}
+            skipped_module_names = prepare_custom_config_dict.get(\
+                'non_traceable_module_name', [])
+            skipped_module_classes = prepare_custom_config_dict.get(\
+                'non_traceable_module_class', [])
+            try:
+                tracer = QuantizationTracer(
+                    skipped_module_names, skipped_module_classes)
+                graph_module = GraphModule(tmp_model, tracer.trace(tmp_model))
+                if self.version >= PyTorchVersionMode.PT111.value:  # pragma: no cover
+                    # pylint: disable=E1124
+                    fused_model = _fuse_fx(graph_module, is_qat,
+                                           fuse_custom_config_dict=prepare_custom_config_dict)
+                else:
+                    fused_model = _fuse_fx(graph_module, prepare_custom_config_dict)
+            except:
+                self.sub_module_list = []
+                self._fuse_sub_graph(tmp_model, prefix='', is_qat=is_qat)
+                fused_model = tmp_model
         except Exception as e:  # pragma: no cover
-            tmp_model = model._model
+            fused_model = model._model
             logger.warning("Deepcopy failed: {}, inplace=True now!".format(repr(e)))
-        tmp_model.train() if is_qat else tmp_model.eval()
-        from torch.fx import GraphModule
-        from torch.quantization.quantize_fx import _fuse_fx, QuantizationTracer
-        if model.kwargs is not None:
-            prepare_custom_config_dict = model.kwargs.get(
-                'prepare_custom_config_dict', {})
-        else:
-            prepare_custom_config_dict = {}
-        skipped_module_names = prepare_custom_config_dict.get(\
-            'non_traceable_module_name', [])
-        skipped_module_classes = prepare_custom_config_dict.get(\
-            'non_traceable_module_class', [])
-        try:
-            tracer = QuantizationTracer(
-                skipped_module_names, skipped_module_classes)
-            graph_module = GraphModule(tmp_model, tracer.trace(tmp_model))
-            if self.version >= PyTorchVersionMode.PT111.value:  # pragma: no cover
-                # pylint: disable=E1124
-                fused_model = _fuse_fx(graph_module, is_qat,
-                                       fuse_custom_config_dict=prepare_custom_config_dict)
-            else:
-                fused_model = _fuse_fx(graph_module, prepare_custom_config_dict)
-        except:
-            self.sub_module_list = []
-            self._fuse_sub_graph(tmp_model, prefix='', is_qat=is_qat)
-            fused_model = tmp_model
         return fused_model

     def _fuse_sub_graph(self, model, prefix, is_qat):
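
For readers who want the control-flow change in isolation: the hunk above moves every step that uses the deep copy inside the outer try, so a failed deepcopy now returns the original model un-fused instead of mutating it halfway. A condensed sketch follows; fuse_fx_safely, pt111_or_newer, and custom_cfg are illustrative names, _fuse_fx and QuantizationTracer are private torch internals called exactly as in the patched code, and the patch's extra fallback to _fuse_sub_graph on tracing errors is omitted here for brevity.

# Condensed, illustrative sketch of the reordered fallback; not the patch itself.
import copy
import logging

logger = logging.getLogger(__name__)

def fuse_fx_safely(model, is_qat, pt111_or_newer, custom_cfg=None):
    custom_cfg = custom_cfg or {}
    try:
        # Everything that touches the deep copy lives inside this try, so a
        # failed deepcopy returns the original model un-fused.
        tmp_model = copy.deepcopy(model)
        tmp_model.train() if is_qat else tmp_model.eval()
        from torch.fx import GraphModule
        from torch.quantization.quantize_fx import _fuse_fx, QuantizationTracer
        tracer = QuantizationTracer(
            custom_cfg.get('non_traceable_module_name', []),
            custom_cfg.get('non_traceable_module_class', []))
        graph_module = GraphModule(tmp_model, tracer.trace(tmp_model))
        if pt111_or_newer:
            # PyTorch >= 1.11 call form, as used in the patched code.
            fused_model = _fuse_fx(graph_module, is_qat,
                                   fuse_custom_config_dict=custom_cfg)
        else:
            fused_model = _fuse_fx(graph_module, custom_cfg)
    except Exception as e:
        logger.warning("Deepcopy/fusion failed: %s, returning model as-is", repr(e))
        fused_model = model
    return fused_model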

neural_compressor/adaptor/pytorch_cpu.yaml

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
              'QuantStub', 'FloatFunctional', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'BNReLU2d',
              'BNReLU3d', 'ConvBn2d', 'ConvBnReLU2d']
     bf16: ['Linear', 'bmm', 'mm', 'baddbmm', 'addmm', 'addbmm',
-           '_convolution', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell', 'Tanh']
+           '_convolution', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
     fp32: ['*'] # '*' means all op types

     capabilities: &1_11_capabilities

neural_compressor/utils/pytorch.py

Lines changed: 5 additions & 0 deletions
@@ -108,6 +108,11 @@ def _load_int8_orchestration(model, tune_cfg, stat_dict, **kwargs):
     version = get_torch_version()
     if version < PyTorchVersionMode.PT111.value:
         quantized_ops["default_qconfig"] = None
+    else:
+        from torch.ao.quantization import default_embedding_qat_qconfig
+        for op in tune_cfg['quantizable_ops']:
+            if op[1] in ['Embedding', 'EmbeddingBag']:
+                quantized_ops[op[0]] = default_embedding_qat_qconfig
     fx_op_cfgs = _cfgs_to_fx_cfgs(quantized_ops, 'quant_aware_training')
     model.train()
     if tune_cfg['sub_module_list'] is None:
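
The loader mirrors the QAT-side change: on PyTorch 1.11+, Embedding and EmbeddingBag ops are pointed at the dedicated embedding QAT qconfig rather than the shared activation/weight QConfig. A minimal sketch of that routing, assuming op entries are (name, type) pairs as in the surrounding code; assign_embedding_qconfigs is an illustrative name, not part of the patch.

# Illustrative helper (not in the patch); requires PyTorch >= 1.11.
from torch.ao.quantization import default_embedding_qat_qconfig

def assign_embedding_qconfigs(quantized_ops, quantizable_ops):
    # quantizable_ops entries are assumed to be (op_name, op_type) pairs.
    for op_name, op_type in quantizable_ops:
        if op_type in ('Embedding', 'EmbeddingBag'):
            # Embedding weights get the embedding-specific QAT fake-quant qconfig.
            quantized_ops[op_name] = default_embedding_qat_qconfig
    return quantized_ops

All other op types keep whatever qconfig is already stored in quantized_ops, which is what both _pre_hook_for_qat and _load_int8_orchestration now do on newer PyTorch.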
