Commit 033e641

enhance get_quantizable_ops on PT (#1122)
1 parent 8e3c3c0 commit 033e641

File tree

1 file changed: +15 -31 lines

neural_compressor/adaptor/pytorch.py

Lines changed: 15 additions & 31 deletions
@@ -731,16 +731,6 @@ def __init__(self, framework_specific_info):
         self.sub_module_list = None
         self.default_qconfig = framework_specific_info['default_qconfig'] \
                                if 'default_qconfig' in framework_specific_info else None
-        self.fused_op = ['nni.ConvReLU1d',
-                         'nni.ConvReLU2d',
-                         'nni.ConvReLU3d',
-                         'nni.LinearReLU',
-                         'nni.BNReLU2d',
-                         'nni.BNReLU3d',
-                         'nniqat.ConvReLU2d',
-                         'nniqat.ConvBn2d',
-                         'nniqat.ConvBnReLU2d',
-                         'nni.LinearReLU']

         if 'approach' in framework_specific_info: # pragma: no cover
             self.approach = framework_specific_info['approach']
@@ -1016,9 +1006,7 @@ def is_fused_module(self, module):
             (bool): is fused or not
         """
         op_type = str(type(module))
-        op_type = op_type[op_type.rfind('.')+1:].strip('>').strip('\'')
-        op_type = 'nni.' + op_type
-        if op_type in self.fused_op:
+        if 'fused' in op_type:
             return True
         else:
             return False
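
Note on the simplified check: the float and QAT fused wrappers that the removed list enumerated (nni.ConvReLU2d, nniqat.ConvBnReLU2d, and so on) are all defined in PyTorch source modules whose path contains "fused", so the class path printed by str(type(module)) carries that substring and the hardcoded list becomes unnecessary. A minimal sketch of the idea, assuming a PyTorch build that ships torch.nn.intrinsic; the names below are illustrative and not part of this commit:

    # Illustrative only: why the substring test matches fused wrappers.
    import torch
    import torch.nn.intrinsic as nni

    def is_fused_module(module):
        # Fused wrappers live under ...intrinsic(.qat).modules.*fused*,
        # so their class path contains 'fused'; ordinary modules do not.
        return 'fused' in str(type(module))

    fused = nni.LinearReLU(torch.nn.Linear(4, 4), torch.nn.ReLU())
    plain = torch.nn.Linear(4, 4)
    print(is_fused_module(fused), is_fused_module(plain))  # expected: True False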
@@ -1385,8 +1373,18 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             None
         """

-        for name, child in model.named_children():
-            op_name = prefix + '.' + name if prefix != '' else name
+        module_dict = dict(model.named_modules())
+        for op_name, child in model.named_modules():
+            if self.is_fused_module(child):
+                for name, _ in child.named_children():
+                    module_prefix = op_name + '.' + name
+                    if module_prefix in module_dict:
+                        module_dict.pop(module_prefix)  # remove sub-modules of fused modules
+                    if op_name in self.fused_dict:
+                        self.fused_dict[op_name] = [self.fused_dict[op_name], module_prefix]
+                    else:
+                        self.fused_dict[op_name] = module_prefix
+        for op_name, child in module_dict.items():
             # there is accuracy issue in quantized LayerNorm op in pytorch <1.8.1,
             # so remove it here
             if op_name in self.non_quant_dict['skipped_module_names'] or \
@@ -1399,15 +1397,6 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                     op_name, unify_op_type_mapping[str(child.__class__.__name__)]
                     if str(child.__class__.__name__) in unify_op_type_mapping else
                     str(child.__class__.__name__)))
-                if self.is_fused_module(child):
-                    for name, _ in child.named_children():
-                        module_prefix = op_name + '.' + name
-                        if op_name in self.fused_dict:
-                            self.fused_dict[op_name] = [self.fused_dict[op_name], module_prefix]
-                        else:
-                            self.fused_dict[op_name] = module_prefix
-            else:
-                self._get_quantizable_ops_recursively(child, op_name, quantizable_ops)

     def _get_scale_zeropoint(self, model, tune_cfg):
         """get activation scale and zero_point for converted model.
@@ -1422,9 +1411,7 @@ def _get_scale_zeropoint(self, model, tune_cfg):
             None
         """
         modules = dict(model.named_modules())
-        for key in tune_cfg['op']:
-            value = tune_cfg['op'][key]
-            assert isinstance(value, dict)
+        for key, value in tune_cfg['op'].items():
             if hasattr(modules[key[0]], 'scale'):
                 value['activation']['scale'] = float(modules[key[0]].scale)
             if hasattr(modules[key[0]], 'zero_point'):
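
Note: iterating tune_cfg['op'].items() gives each key and its config dict in one step, so the explicit lookup and the isinstance assert are dropped. The scale and zero_point copied here come straight off the converted modules; a minimal sketch of where those attributes live, assuming eager-mode static quantization and a PyTorch build with a quantized backend (fbgemm or qnnpack) available:

    # Illustrative only: converted quantized modules expose scale/zero_point.
    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.fc = torch.nn.Linear(4, 4)
            self.dequant = torch.quantization.DeQuantStub()
        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))

    m = M().eval()
    m.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(m, inplace=True)
    m(torch.randn(2, 4))                          # calibration pass
    torch.quantization.convert(m, inplace=True)

    modules = dict(m.named_modules())
    print(float(modules['fc'].scale), int(modules['fc'].zero_point))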
@@ -2965,17 +2952,14 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             None
         """

-        for name, child in model.named_children():
-            op_name = prefix + '.' + name if prefix != '' else name
+        for op_name, child in model.named_modules():
             if type(child) in self.white_list \
               and type(child) != torch.nn.Sequential \
               and type(child) != torch.quantization.stubs.DeQuantStub:
                 quantizable_ops.append((
                     op_name, unify_op_type_mapping[str(child.__class__.__name__)]
                     if str(child.__class__.__name__) in unify_op_type_mapping else
                     str(child.__class__.__name__)))
-            else:
-                self._get_quantizable_ops_recursively(child, op_name, quantizable_ops)

     def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=''):
         """get activation scale and zero_point for converted module.
