Skip to content

Commit 04fb450

Browse files
authored
Remove preserve_zero and zero_point_domain from choose_qparams_affine (#2149)
1 parent 212d912 commit 04fb450

File tree

14 files changed

+1014
-520
lines changed

14 files changed

+1014
-520
lines changed

test/dtypes/test_uintx.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from torchao.quantization.quant_api import quantize_, uintx_weight_only
1111
from torchao.quantization.quant_primitives import (
1212
MappingType,
13-
ZeroPointDomain,
1413
choose_qparams_affine,
1514
dequantize_affine,
1615
quantize_affine,
@@ -112,7 +111,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
112111
mapping_type = MappingType.SYMMETRIC
113112
eps = torch.finfo(torch.float32).eps
114113
zero_point_dtype = torch.int32
115-
zero_point_domain = ZeroPointDomain.INT
116114
block_size = (1, group_size)
117115

118116
scale, zero_point = choose_qparams_affine(
@@ -123,8 +121,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
123121
eps=eps,
124122
scale_dtype=torch.float32,
125123
zero_point_dtype=zero_point_dtype,
126-
preserve_zero=True,
127-
zero_point_domain=zero_point_domain,
128124
)
129125

130126
aqt = quantize_affine(
@@ -133,15 +129,12 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
133129
scale,
134130
zero_point,
135131
dtype,
136-
zero_point_domain=zero_point_domain,
137132
)
138133
# Note: output will be uint8 tensor for sub byte tensors for now
139134

140135
q = to_uintx(aqt, dtype, -1)
141136
assert q is not None, "quantization failed"
142-
deqaunt = dequantize_affine(
143-
q, block_size, scale, zero_point, dtype, zero_point_domain=zero_point_domain
144-
)
137+
deqaunt = dequantize_affine(q, block_size, scale, zero_point, dtype)
145138
assert deqaunt is not None, "deqauntization failed"
146139

147140

test/quantization/test_quant_primitives.py

Lines changed: 27 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,16 @@
99
import unittest
1010

1111
import torch
12-
from parameterized import parameterized
1312

14-
from torchao.float8.float8_utils import EPS as float8_eps
1513
from torchao.quantization.quant_primitives import (
1614
MappingType,
1715
ZeroPointDomain,
1816
choose_qparams_affine,
19-
choose_qparams_affine_float8,
17+
choose_qparams_affine_tinygemm,
2018
dequantize_affine,
21-
dequantize_affine_float8,
2219
fake_quantize_affine,
2320
fake_quantize_affine_cachemask,
2421
quantize_affine,
25-
quantize_affine_float8,
2622
)
2723

2824
# TODO: remove test for utils?
@@ -650,35 +646,6 @@ def test_raises(self):
650646
with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
651647
_ = quantize_affine(input, block_size, scale, zero_point, dtype)
652648

653-
def test_not_preserve_zero_not_supported(self):
654-
"""Making sure preserve_zero == False is not supported for symmetric quant"""
655-
input = torch.randn(10, 256)
656-
n_bit = 4
657-
mapping_type = MappingType.SYMMETRIC
658-
dtype = torch.int8
659-
block_size = (1, 128)
660-
quant_min = 0
661-
quant_max = 2**n_bit - 1
662-
eps = 1e-6
663-
scale_dtype = torch.bfloat16
664-
zero_point_dtype = torch.bfloat16
665-
with self.assertRaisesRegex(
666-
ValueError,
667-
"preserve_zero == False is not supported for symmetric quantization",
668-
):
669-
choose_qparams_affine(
670-
input,
671-
mapping_type,
672-
block_size,
673-
dtype,
674-
quant_min,
675-
quant_max,
676-
eps,
677-
scale_dtype=scale_dtype,
678-
zero_point_dtype=zero_point_dtype,
679-
preserve_zero=False,
680-
)
681-
682649
def test_get_groupwise_affine_qparams(self):
683650
input = torch.randn(10, 256)
684651
n_bit = 4
@@ -702,22 +669,33 @@ def test_get_groupwise_affine_qparams(self):
702669
dtype=torch.bfloat16,
703670
zero_point_domain=zero_point_domain,
704671
)
705-
scale, zero_point = choose_qparams_affine(
706-
input,
707-
mapping_type,
708-
block_size,
709-
dtype,
710-
quant_min,
711-
quant_max,
712-
eps,
713-
scale_dtype=scale_dtype,
714-
zero_point_dtype=zero_point_dtype,
715-
preserve_zero=zero_point_domain == ZeroPointDomain.INT,
716-
zero_point_domain=zero_point_domain,
717-
)
672+
if zero_point_domain == ZeroPointDomain.FLOAT:
673+
scale, zero_point = choose_qparams_affine_tinygemm(
674+
input,
675+
mapping_type,
676+
block_size,
677+
dtype,
678+
quant_min,
679+
quant_max,
680+
eps,
681+
scale_dtype=scale_dtype,
682+
zero_point_dtype=zero_point_dtype,
683+
)
684+
else:
685+
scale, zero_point = choose_qparams_affine(
686+
input,
687+
mapping_type,
688+
block_size,
689+
dtype,
690+
quant_min,
691+
quant_max,
692+
eps,
693+
scale_dtype=scale_dtype,
694+
zero_point_dtype=zero_point_dtype,
695+
)
718696

719-
self.assertTrue(torch.equal(scale, scale_ref))
720-
self.assertTrue(torch.equal(zero_point, zero_point_ref))
697+
self.assertTrue(torch.equal(scale, scale_ref))
698+
self.assertTrue(torch.equal(zero_point, zero_point_ref))
721699

722700
def test_groupwise_affine_quantize_tensor_from_qparams(self):
723701
input = torch.randn(10, 256)
@@ -847,120 +825,6 @@ def test_fake_quantize_affine_cachemask(self):
847825
torch.testing.assert_close(dequantized, fake_quantized)
848826
torch.testing.assert_close(expected_mask, mask)
849827

850-
def test_none_zero_point_domain(self):
851-
"""A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
852-
input = torch.randn(10, 256)
853-
mapping_type = MappingType.SYMMETRIC
854-
dtype = torch.int8
855-
block_size = (1, 128)
856-
quant_min = None
857-
quant_max = None
858-
eps = 1e-6
859-
scale_dtype = torch.float32
860-
zero_point_dtype = torch.int64
861-
try:
862-
_, zero_point = choose_qparams_affine(
863-
input,
864-
mapping_type,
865-
block_size,
866-
dtype,
867-
quant_min,
868-
quant_max,
869-
eps,
870-
scale_dtype=scale_dtype,
871-
zero_point_dtype=zero_point_dtype,
872-
preserve_zero=True,
873-
zero_point_domain=None,
874-
)
875-
except ValueError:
876-
# This exception was expected
877-
# Now test for ZeroPointDomain.NONE
878-
_, zero_point = choose_qparams_affine(
879-
input,
880-
mapping_type,
881-
block_size,
882-
dtype,
883-
quant_min,
884-
quant_max,
885-
eps,
886-
scale_dtype=scale_dtype,
887-
zero_point_dtype=zero_point_dtype,
888-
preserve_zero=True,
889-
zero_point_domain=ZeroPointDomain.NONE,
890-
)
891-
self.assertTrue(zero_point is None)
892-
else:
893-
# An exception should have been thrown for zero_point_domain None
894-
self.assertTrue(
895-
False,
896-
msg="A runtime exception should have been thrown for zero_point_domain None",
897-
)
898-
899-
@parameterized.expand(
900-
[
901-
(
902-
torch.float32,
903-
torch.float8_e4m3fn,
904-
),
905-
(
906-
torch.float32,
907-
torch.float8_e5m2,
908-
),
909-
(
910-
torch.bfloat16,
911-
torch.float8_e4m3fn,
912-
),
913-
(
914-
torch.bfloat16,
915-
torch.float8_e5m2,
916-
),
917-
]
918-
)
919-
def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
920-
input = torch.randn(10, 10)
921-
922-
# float8 quantization primitives
923-
scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
924-
quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
925-
dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)
926-
927-
# reference implementation using generic primitives
928-
expected_scale, _ = choose_qparams_affine(
929-
input,
930-
MappingType.SYMMETRIC,
931-
input.shape,
932-
float8_dtype,
933-
eps=float8_eps, # use same EPS as float8 training
934-
scale_dtype=torch.float32,
935-
quant_min=torch.finfo(float8_dtype).min,
936-
quant_max=torch.finfo(float8_dtype).max,
937-
)
938-
expected_quantized = quantize_affine(
939-
input,
940-
input.shape,
941-
scale,
942-
output_dtype=float8_dtype,
943-
quant_min=torch.finfo(float8_dtype).min,
944-
quant_max=torch.finfo(float8_dtype).max,
945-
zero_point=None,
946-
zero_point_domain=ZeroPointDomain.NONE,
947-
)
948-
expected_dequantized = dequantize_affine(
949-
expected_quantized,
950-
input.shape,
951-
scale,
952-
input_dtype=float8_dtype,
953-
output_dtype=hp_dtype,
954-
quant_min=torch.finfo(float8_dtype).min,
955-
quant_max=torch.finfo(float8_dtype).max,
956-
zero_point=None,
957-
zero_point_domain=ZeroPointDomain.NONE,
958-
)
959-
960-
self.assertTrue(torch.equal(expected_scale, scale))
961-
torch.testing.assert_close(expected_quantized, quantized)
962-
torch.testing.assert_close(expected_dequantized, dequantized)
963-
964828

965829
if __name__ == "__main__":
966830
unittest.main()

test/sparsity/test_marlin.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from torchao.quantization.quant_api import int4_weight_only, quantize_
1515
from torchao.quantization.quant_primitives import (
1616
MappingType,
17-
ZeroPointDomain,
1817
choose_qparams_affine,
1918
quantize_affine,
2019
)
@@ -92,8 +91,6 @@ def test_pack_unpack_equivalence(self):
9291
eps = 1e-6
9392
zero_point_dtype = torch.bfloat16
9493
mapping_type = MappingType.SYMMETRIC
95-
preserve_zero = True
96-
zero_point_domain = ZeroPointDomain.INT
9794
scale_dtype = None
9895

9996
w = torch.rand(shape, dtype=torch.float16, device="cuda")
@@ -112,8 +109,6 @@ def test_pack_unpack_equivalence(self):
112109
eps,
113110
scale_dtype,
114111
zero_point_dtype,
115-
preserve_zero,
116-
zero_point_domain,
117112
)
118113
w_q_24 = quantize_affine(
119114
w_24,
@@ -123,7 +118,6 @@ def test_pack_unpack_equivalence(self):
123118
target_dtype,
124119
quant_min,
125120
quant_max,
126-
zero_point_domain,
127121
)
128122
scales = scales.reshape(-1, w_q_24.shape[1])
129123

0 commit comments

Comments
 (0)