Commit 6148fef

[Bugfix] Fix shape calculation for group quantization (#308)
* use ceil of group size
* use math, not torch
* fix false assumption in tests

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 1068c84 commit 6148fef

2 files changed: +5 additions, −2 deletions

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -14,6 +14,7 @@


 import logging
+import math
 from enum import Enum
 from typing import Optional

@@ -162,7 +163,7 @@ def _initialize_scale_zero_point(
            # (output_channels, 1)
            expected_shape = (weight_shape[0], 1)
        elif quantization_args.strategy == QuantizationStrategy.GROUP:
-            num_groups = weight_shape[1] // quantization_args.group_size
+            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
            expected_shape = (weight_shape[0], max(num_groups, 1))

    scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
```
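The change matters whenever the number of input columns is not an exact multiple of group_size: floor division drops the trailing partial group, so the scale and zero-point tensors come out one group short. A minimal sketch of the arithmetic, using made-up sizes (10 columns, groups of 4) rather than values from the commit:

```python
import math

# Hypothetical sizes for illustration only.
columns, group_size = 10, 4

floor_groups = columns // group_size           # 2 -> misses the trailing partial group
ceil_groups = math.ceil(columns / group_size)  # 3 -> one scale per (possibly partial) group

print(floor_groups, ceil_groups)  # 2 3
```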

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -13,6 +13,8 @@
 # limitations under the License.


+import math
+
 import pytest
 from compressed_tensors.quantization import (
     ActivationOrdering,
@@ -183,7 +185,7 @@ def test_initialize_quantization_parameters(weights, input_activations):
        expected_shape = (layer.weight.shape[0], 1)

    elif args.strategy == QuantizationStrategy.GROUP:  # only weight
-        num_groups = layer.weight.shape[1] // args.group_size
+        num_groups = math.ceil(layer.weight.shape[1] / args.group_size)
        expected_shape = (layer.weight.shape[0], max(num_groups, 1))

    elif args.strategy == QuantizationStrategy.BLOCK:
```
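The test mirrors the ceil-based formula so the expected shape also matches weights whose column count is not divisible by the group size. Below is a standalone pytest-style sketch of that invariant, using hypothetical sizes rather than the fixtures from the actual test file:

```python
import math

import pytest


@pytest.mark.parametrize("columns,group_size", [(128, 128), (130, 128), (10, 4)])
def test_group_count_covers_partial_groups(columns, group_size):
    # Hypothetical standalone check, independent of compressed_tensors fixtures.
    num_groups = math.ceil(columns / group_size)
    assert num_groups * group_size >= columns       # every column falls in some group
    assert (num_groups - 1) * group_size < columns  # no empty trailing group
```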
