@@ -2717,13 +2717,25 @@ def _pre_hook_for_qat(self):
2717
2717
qscheme = torch .per_tensor_affine ,
2718
2718
reduce_range = REDUCE_RANGE ,
2719
2719
observer = torch .quantization .MovingAverageMinMaxObserver ),
2720
- weight = torch .quantization .default_weight_fake_quant )
2720
+ weight = torch .quantization .default_weight_fake_quant ) \
2721
+ if self .version < PyTorchVersionMode .PT110 .value else \
2722
+ torch .quantization .QConfig (
2723
+ activation = torch .quantization .FusedMovingAvgObsFakeQuantize .with_args (
2724
+ dtype = torch .quint8 ,
2725
+ qscheme = torch .per_tensor_affine ,
2726
+ reduce_range = REDUCE_RANGE ),
2727
+ weight = torch .quantization .default_fused_per_channel_wt_fake_quant )
2721
2728
quantizable_ops = []
2722
2729
tmp_model = self .fuse_fx_model (self .model , is_qat = True )
2723
2730
self ._get_quantizable_ops_recursively (tmp_model , '' , quantizable_ops )
2724
2731
quantized_ops = {op [0 ]:q_cfgs for op in quantizable_ops }
2725
2732
if self .version < PyTorchVersionMode .PT111 .value :
2726
2733
quantized_ops ["default_qconfig" ] = None
2734
+ else :
2735
+ from torch .ao .quantization import default_embedding_qat_qconfig
2736
+ for op in quantizable_ops :
2737
+ if op [1 ] in ['Embedding' , 'EmbeddingBag' ]:
2738
+ quantized_ops [op [0 ]] = default_embedding_qat_qconfig
2727
2739
from torch .quantization .quantize_fx import prepare_qat_fx
2728
2740
fx_op_cfgs = _cfgs_to_fx_cfgs (quantized_ops , 'quant_aware_training' )
2729
2741
self .model ._model .train ()
@@ -3136,35 +3148,35 @@ def fuse_fx_model(self, model, is_qat):
3136
3148
"""
3137
3149
try :
3138
3150
tmp_model = copy .deepcopy (model ._model )
3151
+ tmp_model .train () if is_qat else tmp_model .eval ()
3152
+ from torch .fx import GraphModule
3153
+ from torch .quantization .quantize_fx import _fuse_fx , QuantizationTracer
3154
+ if model .kwargs is not None :
3155
+ prepare_custom_config_dict = model .kwargs .get (
3156
+ 'prepare_custom_config_dict' , {})
3157
+ else :
3158
+ prepare_custom_config_dict = {}
3159
+ skipped_module_names = prepare_custom_config_dict .get (\
3160
+ 'non_traceable_module_name' , [])
3161
+ skipped_module_classes = prepare_custom_config_dict .get (\
3162
+ 'non_traceable_module_class' , [])
3163
+ try :
3164
+ tracer = QuantizationTracer (
3165
+ skipped_module_names , skipped_module_classes )
3166
+ graph_module = GraphModule (tmp_model , tracer .trace (tmp_model ))
3167
+ if self .version >= PyTorchVersionMode .PT111 .value : # pragma: no cover
3168
+ # pylint: disable=E1124
3169
+ fused_model = _fuse_fx (graph_module , is_qat ,
3170
+ fuse_custom_config_dict = prepare_custom_config_dict )
3171
+ else :
3172
+ fused_model = _fuse_fx (graph_module , prepare_custom_config_dict )
3173
+ except :
3174
+ self .sub_module_list = []
3175
+ self ._fuse_sub_graph (tmp_model , prefix = '' , is_qat = is_qat )
3176
+ fused_model = tmp_model
3139
3177
except Exception as e : # pragma: no cover
3140
- tmp_model = model ._model
3178
+ fused_model = model ._model
3141
3179
logger .warning ("Deepcopy failed: {}, inplace=True now!" .format (repr (e )))
3142
- tmp_model .train () if is_qat else tmp_model .eval ()
3143
- from torch .fx import GraphModule
3144
- from torch .quantization .quantize_fx import _fuse_fx , QuantizationTracer
3145
- if model .kwargs is not None :
3146
- prepare_custom_config_dict = model .kwargs .get (
3147
- 'prepare_custom_config_dict' , {})
3148
- else :
3149
- prepare_custom_config_dict = {}
3150
- skipped_module_names = prepare_custom_config_dict .get (\
3151
- 'non_traceable_module_name' , [])
3152
- skipped_module_classes = prepare_custom_config_dict .get (\
3153
- 'non_traceable_module_class' , [])
3154
- try :
3155
- tracer = QuantizationTracer (
3156
- skipped_module_names , skipped_module_classes )
3157
- graph_module = GraphModule (tmp_model , tracer .trace (tmp_model ))
3158
- if self .version >= PyTorchVersionMode .PT111 .value : # pragma: no cover
3159
- # pylint: disable=E1124
3160
- fused_model = _fuse_fx (graph_module , is_qat ,
3161
- fuse_custom_config_dict = prepare_custom_config_dict )
3162
- else :
3163
- fused_model = _fuse_fx (graph_module , prepare_custom_config_dict )
3164
- except :
3165
- self .sub_module_list = []
3166
- self ._fuse_sub_graph (tmp_model , prefix = '' , is_qat = is_qat )
3167
- fused_model = tmp_model
3168
3180
return fused_model
3169
3181
3170
3182
def _fuse_sub_graph (self , model , prefix , is_qat ):
0 commit comments