Commit 3a88875
Some fixes for AWQ (#269)

* Some fixes for AWQ
* revert clamp to 1e-5
* rename awq quant preset to W4A16_ASYM
* revert changes to min_vals/max_vals
* only round if casting to int type

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 545e426 commit 3a88875

2 files changed (+20 −1 lines changed)


src/compressed_tensors/quantization/quant_scheme.py
13 additions, 0 deletions

@@ -142,6 +142,18 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )
 
+# 4 bit integer weights only asymmetric quantization
+W4A16_ASYM = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+        symmetric=False,
+        dynamic=False,
+    ),
+)
+
 # 4 bit integer weights and 8 bit activations quantization
 INT8_W4A8 = dict(
     weights=QuantizationArgs(
@@ -205,6 +217,7 @@ def is_preset_scheme(name: str) -> bool:
     # Integer weight only schemes
     "W8A16": W8A16,
     "W4A16": W4A16,
+    "W4A16_ASYM": W4A16_ASYM,
     # Integer weight and activation schemes
     "W8A8": INT8_W8A8,
     "INT8": INT8_W8A8,  # alias for W8A8

src/compressed_tensors/quantization/utils/helpers.py
7 additions, 1 deletion

@@ -64,8 +64,11 @@ def calculate_qparams(
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
     """
+    # based on the implementations for consuming quantized values,
+    # 0.0 must always be representable within the quantized range
     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
     device = min_vals.device
 
     bit_min, bit_max = calculate_range(quantization_args, device)
@@ -84,6 +87,9 @@
     zero_points = torch.clamp(zero_points, bit_min, bit_max)
 
     # match zero-points to quantized type
+    # if casting to int, use round instead of truncate
+    if quantization_args.type == QuantizationType.INT:
+        zero_points = torch.round(zero_points)
     zero_points = zero_points.to(zp_dtype)
 
     if scales.ndim == 0:
@@ -96,7 +102,7 @@
 def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
     """
     Returns the computed scales and zero points for dynamic activation
-    qunatization.
+    quantization.
 
     :param value: tensor to calculate quantization parameters for
     :param args: quantization args
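To see why rounding matters when casting zero points to an integer dtype, here is a self-contained sketch of the asymmetric branch, assuming a 4-bit signed range of [-8, 7]; it mirrors the steps in the hunks above but is not the library's calculate_qparams:

import torch

# Standalone sketch (not calculate_qparams itself): asymmetric scale and
# zero point for a 4-bit signed range, with the rounding fix applied.
def asym_qparams(min_vals, max_vals, bit_min=-8, bit_max=7):
    # 0.0 must always be representable within the quantized range
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
    scales = (max_vals - min_vals) / float(bit_max - bit_min)
    zero_points = bit_min - min_vals / scales
    zero_points = torch.clamp(zero_points, bit_min, bit_max)
    # round to nearest before casting; a bare .to(torch.int8) truncates
    # toward zero and biases the zero point
    zero_points = torch.round(zero_points).to(torch.int8)
    return scales, zero_points

scales, zps = asym_qparams(torch.tensor([-0.9]), torch.tensor([2.3]))
# scale = 3.2 / 15 ~= 0.2133; raw zero point = -8 + 0.9 / 0.2133 ~= -3.78,
# which rounds to -4; truncation would have produced -3
print(scales, zps)

Truncation pulls every fractional zero point toward zero, shifting the dequantized grid by up to a full step; rounding keeps the error within half a step.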
