diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 9e9370ef69..0024b52dbe 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -233,7 +233,7 @@ def annotate_matmul_input1(node: Node):
     )
     quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
         act_dtype=torch.uint8,
-        weight_dtype="int4",
+        weight_dtype=torch.int4,
         act_observer=MinMaxObserver,
         act_symmetric=True,
     )
diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index e2a9cd8356..748128ceaf 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -241,8 +241,7 @@ def get_ptq_per_channel_quant_config(
         torch.int8,
         torch.int16,
     }
-    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
-    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
+    supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
     assert (
         act_dtype in supported_act_types
     ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -276,9 +275,11 @@ def get_ptq_per_channel_quant_config(
     )
 
     weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
@@ -310,9 +311,11 @@ def get_ptq_per_block_quant_config(
         act_symmetric=act_symmetric,
     )
     weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args),
@@ -463,8 +466,7 @@ def get_qat_per_channel_quant_config(
         torch.int8,
         torch.int16,
     }
-    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
-    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
+    supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
     assert (
         act_dtype in supported_act_types
     ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -491,17 +493,21 @@ def get_qat_per_channel_quant_config(
     )
 
     weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer=MovingAveragePerChannelMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 9a149e7db8..7298e02aa0 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -85,7 +85,7 @@ class QuantDtype(IntEnum):
         partial(
             get_ptq_per_channel_quant_config,
             act_dtype=torch.uint16,
-            weight_dtype="int4",
+            weight_dtype=torch.int4,
         ),
         None,
     ),
@@ -94,12 +94,12 @@ class QuantDtype(IntEnum):
         partial(
             get_ptq_per_channel_quant_config,
             act_dtype=torch.uint16,
-            weight_dtype="int4",
+            weight_dtype=torch.int4,
         ),
         partial(
             get_ptq_per_block_quant_config,
             act_dtype=torch.uint16,
-            weight_dtype="int4",
+            weight_dtype=torch.int4,
         ),
     ),
     (QuantDtype.use_8a8w, False): (
@@ -113,7 +113,7 @@ class QuantDtype(IntEnum):
         partial(
             get_qat_per_channel_quant_config,
             act_dtype=torch.uint16,
-            weight_dtype="int4",
+            weight_dtype=torch.int4,
         ),
         None,
     ),