
Commit b10876b

Bump Int4WeightOnlyConfig version to 2 (#2949)
Bump int4 weight only config version to 2

Summary: Int4WeightOnlyConfig currently has versions 1 and 2, with version 1 as the default. This PR changes the default to 2 and updates the callsites accordingly. Callsites that still rely on the old configuration now pass an explicit `version=1`; those can be migrated to version 2 separately. For READMEs, the usage is migrated to version 2 directly.

Deprecation: TODO

Test Plan: Regression tests:
python test/dtypes/test_affine_quantized.py
python test/quantization/test_quant_api.py
python test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
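For reference, a minimal sketch (not part of this commit) of what the default bump means at a callsite, using the torchao APIs exercised in the diffs below; the toy model and shapes are illustrative only:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# illustrative toy module; any model containing nn.Linear works the same way
model = torch.nn.Sequential(torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda"))

# After this change, omitting `version` selects version 2, which is configured through
# the int4_packing_format / int4_choose_qparams_algorithm arguments:
quantize_(model, Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d"))

# Callsites that still need the old behavior now have to opt in explicitly:
# quantize_(model, Int4WeightOnlyConfig(group_size=128, version=1))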
1 parent 8b72284 commit b10876b

File tree

18 files changed: +131 −45 lines changed


.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 4 additions & 5 deletions
@@ -206,7 +206,7 @@ def _untie_weights_and_save_locally(model_id):
 
 _int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig
-quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)
+quant_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -256,7 +256,7 @@ def _untie_weights_and_save_locally(model_id):
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+base_config = Int4WeightOnlyConfig(group_size=128)
 quant_config = AWQConfig(base_config, step="prepare")
 quantize_(
     model,
@@ -633,9 +633,8 @@ def quantize_and_upload(
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
         "INT4": Int4WeightOnlyConfig(
             group_size=128,
-            packing_format="tile_packed_to_4d",
+            int4_packing_format="tile_packed_to_4d",
             int4_choose_qparams_algorithm="hqq",
-            version=2,
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
@@ -669,7 +668,7 @@ def quantize_and_upload(
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+    base_config = Int4WeightOnlyConfig(group_size=128)
     quant_config = AWQConfig(base_config, step="prepare")
     quantize_(
         model,

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ def test_tp(self, dtype):
 
 class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
     QUANT_METHOD_FN = staticmethod(int4_weight_only)
+    QUANT_METHOD_KWARGS = {"version": 1}
     COMMON_DTYPES = [torch.bfloat16]
 
     @common_utils.parametrize("dtype", COMMON_DTYPES)

test/integration/test_load_and_run_checkpoint.py

Lines changed: 93 additions & 18 deletions
@@ -24,9 +24,22 @@
 
 # please check model card for how to generate these models
 
-_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
+# high precision model, used for testing config deprecation warning
+_HIGH_PRECISION_MODEL = "facebook/opt-125m"
+
+_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
+        1,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
+    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 _DEPRECATED_MODEL_INFO = [
@@ -36,15 +49,33 @@
         1,
         "Float8DynamicActivationFloat8WeightConfig",
     ),
+    # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
-_SINGLE_LINEAR_MODEL_NAMES = [
+_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+        2,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 
@@ -55,7 +86,9 @@
     "Skipping the test in fbcode for now, not sure how to download from transformers",
 )
 class TestLoadAndRunCheckpoint(TestCase):
-    def _test_single_linear_helper(self, model_name):
+    def _test_single_linear_helper(
+        self, model_name, version, config_name, is_deprecated
+    ):
         from huggingface_hub import hf_hub_download
 
         downloaded_model = hf_hub_download(model_name, filename="model.pt")
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
         model = torch.nn.Sequential(
             torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
         )
-        with open(downloaded_model, "rb") as f:
+
+        with (
+            open(downloaded_model, "rb") as f,
+            warnings.catch_warnings(record=True) as caught_warnings,
+        ):
             model.load_state_dict(torch.load(f), assign=True)
+            if is_deprecated:
+                assert any(
+                    f"Models quantized with version {version} of {config_name} is deprecated"
+                    in str(w.message)
+                    for w in caught_warnings
+                ), (
+                    f"Didn't get expected warning message for deprecation for model: {model_name}"
+                )
 
         downloaded_example_inputs = hf_hub_download(
             model_name, filename="model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
         output = model(*example_inputs)
         self.assertTrue(torch.equal(output, ref_output))
 
-    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
-    def test_deprecated_single_linear(self, model_name):
-        self._test_single_linear_helper(model_name)
+    @common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
+    def test_deprecated_single_linear(self, model_info):
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=True
+        )
 
-    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
-    def test_single_linear(self, model_name):
+    @common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
+    def test_single_linear(self, model_info):
         """Test that we can load and run the quantized linear checkpoint with saved sample input
         and match the saved output, to make sure there is no BC breaking changes
         when we make changes to tensor subclass implementations
         """
-        self._test_single_linear_helper(model_name)
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=False
+        )
 
     @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
     def test_deprecated_hf_models(self, model_info):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
                 torch_dtype="bfloat16",
                 device_map="cuda:0",
             )
+            # version mismatch check in config.py
             assert any(
                 "Stored version is not the same as current default version of the config"
                 in str(w.message)
                 for w in caught_warnings
-            ), "Didn't get expected warning message for version mismatch"
+            ), (
+                f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}"
+            )
 
+            # checkpoint deprecation
             assert any(
-                f"Models quantized with version 1 of {config_name} is deprecated"
+                f"Models quantized with version {version} of {config_name} is deprecated"
                 in str(w.message)
                 for w in caught_warnings
-            ), "Didn't get expected warning message for deprecation"
+            ), (
+                f"Didn't get expected warning message for deprecation for model {model_name}"
+            )
         assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
         assert (
             quantized_model.config.quantization_config.quant_type.version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
             return_tensors="pt",
         ).to("cuda")
         generated_ids = quantized_model.generate(
-            **inputs, max_new_tokens=128, temperature=0
+            **inputs,
+            max_new_tokens=128,
         )
 
         downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
+        # make sure we throw warning for config deprecation
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            _ = AutoModelForCausalLM.from_pretrained(
+                _HIGH_PRECISION_MODEL,
+                torch_dtype="bfloat16",
+                device_map="cuda:0",
+                quantization_config=quantized_model.config.quantization_config,
+            )
+        # config version deprecation in quant_api.py
+        assert any(
+            f"Config Deprecation: version {version} of {config_name} is deprecated and will no longer be supported in a future release"
+            in str(w.message)
+            for w in caught_warnings
+        ), (
+            f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}"
+        )
+
 
 common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
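The new assertions above pin down the warning text a user should see when loading an old serialized model. A rough sketch of that user-facing flow (not part of the commit), assuming a single-linear checkpoint produced with version 1 of Int4WeightOnlyConfig and saved locally as model.pt, as in the torchao-testing checkpoints referenced above:

import warnings
import torch

model = torch.nn.Sequential(torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda"))
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")  # make sure the deprecation warning is recorded
    with open("model.pt", "rb") as f:  # hypothetical local copy of the version-1 checkpoint
        model.load_state_dict(torch.load(f), assign=True)
# per the test above, one of the recorded warnings should contain:
# "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated"
print([str(w.message) for w in caught_warnings])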
test/prototype/test_awq.py

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -123,7 +123,7 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -177,7 +177,7 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="marlin_sparse",
-    version=2,
 )
 
 
test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="opaque",
-        version=2,
     )
 
 
test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="plain_int32",
-        version=2,
     )
 
 
test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="preshuffled",
-    version=2,
 )
 
 # only 128 group_size is supported

test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
             int4_packing_format="plain",
-            version=2,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Lines changed: 0 additions & 2 deletions
@@ -25,14 +25,12 @@
 INT4_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="tile_packed_to_4d",
-    version=2,
 )
 
 INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="tile_packed_to_4d",
     int4_choose_qparams_algorithm="hqq",
-    version=2,
 )
 
 