@@ -731,16 +731,6 @@ def __init__(self, framework_specific_info):
         self.sub_module_list = None
         self.default_qconfig = framework_specific_info['default_qconfig'] \
             if 'default_qconfig' in framework_specific_info else None
-        self.fused_op = ['nni.ConvReLU1d',
-                         'nni.ConvReLU2d',
-                         'nni.ConvReLU3d',
-                         'nni.LinearReLU',
-                         'nni.BNReLU2d',
-                         'nni.BNReLU3d',
-                         'nniqat.ConvReLU2d',
-                         'nniqat.ConvBn2d',
-                         'nniqat.ConvBnReLU2d',
-                         'nni.LinearReLU']
 
         if 'approach' in framework_specific_info:  # pragma: no cover
             self.approach = framework_specific_info['approach']
@@ -1016,9 +1006,7 @@ def is_fused_module(self, module):
             (bool): is fused or not
         """
         op_type = str(type(module))
-        op_type = op_type[op_type.rfind('.')+1:].strip('>').strip('\'')
-        op_type = 'nni.' + op_type
-        if op_type in self.fused_op:
+        if 'fused' in op_type:
             return True
         else:
             return False
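
The substring check above relies on PyTorch placing its intrinsic fused wrappers in modules whose qualified name contains `fused` (e.g. `torch.nn.intrinsic.modules.fused.ConvReLU2d`), which is what makes the hardcoded allow-list deleted in the first hunk unnecessary. A minimal sketch of the idea, assuming a PyTorch build where `torch.nn.intrinsic` is available:

    import torch
    import torch.nn.intrinsic as nni

    # A fused Conv+ReLU wrapper of the kind produced by module fusion.
    fused = nni.ConvReLU2d(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    # str(type(...)) embeds the defining module's qualified name, e.g.
    # "<class 'torch.nn.intrinsic.modules.fused.ConvReLU2d'>".
    print('fused' in str(type(fused)))             # True
    print('fused' in str(type(torch.nn.ReLU())))   # False
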
@@ -1385,8 +1373,18 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             None
         """
 
-        for name, child in model.named_children():
-            op_name = prefix + '.' + name if prefix != '' else name
+        module_dict = dict(model.named_modules())
+        for op_name, child in model.named_modules():
+            if self.is_fused_module(child):
+                for name, _ in child.named_children():
+                    module_prefix = op_name + '.' + name
+                    if module_prefix in module_dict:
+                        module_dict.pop(module_prefix)  # remove sub-modules of fused modules
+                    if op_name in self.fused_dict:
+                        self.fused_dict[op_name] = [self.fused_dict[op_name], module_prefix]
+                    else:
+                        self.fused_dict[op_name] = module_prefix
+        for op_name, child in module_dict.items():
             # there is accuracy issue in quantized LayerNorm op in pytorch <1.8.1,
             # so remove it here
             if op_name in self.non_quant_dict['skipped_module_names'] or \
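
`named_modules()` already walks the whole module tree and yields fully qualified names (including the root model under the empty name `''`), which is what lets this hunk, and the similar one further down, drop the explicit recursion. A small self-contained sketch of the pruning idea on a hypothetical toy model:

    import torch
    import torch.nn.intrinsic as nni

    model = torch.nn.Sequential(
        nni.ConvReLU2d(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()),  # fused wrapper '0'
        torch.nn.Linear(8, 4),                                      # plain module '1'
    )
    module_dict = dict(model.named_modules())
    fused_dict = {}
    for op_name, child in model.named_modules():
        if 'fused' in str(type(child)):              # stand-in for is_fused_module
            for name, _ in child.named_children():
                module_prefix = op_name + '.' + name
                module_dict.pop(module_prefix, None)  # drops '0.0' and '0.1'
                if op_name in fused_dict:
                    fused_dict[op_name] = [fused_dict[op_name], module_prefix]
                else:
                    fused_dict[op_name] = module_prefix
    print(sorted(module_dict))  # ['', '0', '1'] -- fused children pruned
    print(fused_dict)           # {'0': ['0.0', '0.1']}

Only the surviving entries of `module_dict` are then considered for quantization, so each fused block is visited once at its wrapper rather than once per sub-module; the root entry `''` is expected to be filtered out by the type and skip checks that follow.
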
@@ -1399,15 +1397,6 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                     op_name, unify_op_type_mapping[str(child.__class__.__name__)]
                     if str(child.__class__.__name__) in unify_op_type_mapping else
                     str(child.__class__.__name__)))
-            if self.is_fused_module(child):
-                for name, _ in child.named_children():
-                    module_prefix = op_name + '.' + name
-                    if op_name in self.fused_dict:
-                        self.fused_dict[op_name] = [self.fused_dict[op_name], module_prefix]
-                    else:
-                        self.fused_dict[op_name] = module_prefix
-            else:
-                self._get_quantizable_ops_recursively(child, op_name, quantizable_ops)
 
     def _get_scale_zeropoint(self, model, tune_cfg):
         """get activation scale and zero_point for converted model.
@@ -1422,9 +1411,7 @@ def _get_scale_zeropoint(self, model, tune_cfg):
             None
         """
         modules = dict(model.named_modules())
-        for key in tune_cfg['op']:
-            value = tune_cfg['op'][key]
-            assert isinstance(value, dict)
+        for key, value in tune_cfg['op'].items():
             if hasattr(modules[key[0]], 'scale'):
                 value['activation']['scale'] = float(modules[key[0]].scale)
             if hasattr(modules[key[0]], 'zero_point'):
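
For reference, the keys of `tune_cfg['op']` appear to be `(op_name, op_type)` tuples, which is why `key[0]` is used to index the `named_modules()` dictionary. A hypothetical illustration of the shape this loop assumes (field values are made up):

    # Hypothetical tune_cfg fragment; real configs come from the tuning strategy.
    tune_cfg = {
        'op': {
            ('conv1', 'Conv2d'): {'activation': {'dtype': 'uint8'},
                                  'weight': {'dtype': 'int8'}},
        }
    }
    for key, value in tune_cfg['op'].items():
        op_name, op_type = key
        print(op_name, value['activation'])   # conv1 {'dtype': 'uint8'}

The `.items()` form also drops the per-key lookup and the `isinstance` sanity check, since `value` is read directly from the dict being iterated.
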
@@ -2965,17 +2952,14 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             None
         """
 
-        for name, child in model.named_children():
-            op_name = prefix + '.' + name if prefix != '' else name
+        for op_name, child in model.named_modules():
             if type(child) in self.white_list \
                 and type(child) != torch.nn.Sequential \
                 and type(child) != torch.quantization.stubs.DeQuantStub:
                 quantizable_ops.append((
                     op_name, unify_op_type_mapping[str(child.__class__.__name__)]
                     if str(child.__class__.__name__) in unify_op_type_mapping else
                     str(child.__class__.__name__)))
-            else:
-                self._get_quantizable_ops_recursively(child, op_name, quantizable_ops)
 
     def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=''):
         """get activation scale and zero_point for converted module.