Commit 7acf6fa

Deprecate config functions like int4_weight_only
**Summary:** These have been superseded by `AOBaseConfig` objects for several releases already, but we never deprecated them. We will keep them around for another release before breaking BC and removing them.

**Test Plan:**
```
python test/quantization/test_quant_api.py -k test_config_deprecation
```

[ghstack-poisoned]
1 parent 10ba659 · commit 7acf6fa
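For context, the migration the commit message refers to looks roughly like the sketch below. This is illustrative and not taken from the diff itself; it assumes `quantize_` and `Int4WeightOnlyConfig` are importable from `torchao.quantization` (as the changed files suggest) and that the model is in bfloat16 on a CUDA device, which int4 weight-only quantization typically expects.

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, int4_weight_only, quantize_

# Toy model for illustration only
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Before (deprecated): the config function still works for now, but emits a warning
# pointing at the replacement config class.
config = int4_weight_only()  # returns an Int4WeightOnlyConfig instance

# After (preferred): construct the AOBaseConfig object directly and pass it to quantize_.
config = Int4WeightOnlyConfig()
quantize_(model, config)
```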

File tree: 3 files changed, +84 −9 lines changed

test/quantization/test_quant_api.py

Lines changed: 45 additions & 0 deletions

```diff
@@ -10,6 +10,7 @@
 import gc
 import tempfile
 import unittest
+import warnings
 from pathlib import Path
 
 import torch
@@ -752,6 +753,50 @@ def test_int4wo_cuda_serialization(self):
         # load state_dict in cuda
         model.load_state_dict(sd, assign=True)
 
+    def test_config_deprecation(self):
+        """
+        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
+        """
+        from torchao.quantization import (
+            float8_weight_only,
+            fpx_weight_only,
+            gemlite_uintx_weight_only,
+            int4_weight_only,
+            int8_dynamic_activation_int4_weight,
+            int8_dynamic_activation_int8_weight,
+            int8_weight_only,
+            uintx_weight_only,
+        )
+
+        # Reset deprecation warning state, otherwise we won't log warnings here
+        warnings.resetwarnings()
+
+        # Map from deprecated API to the args needed to instantiate it
+        deprecated_apis_to_args = {
+            float8_weight_only: (),
+            fpx_weight_only: (3, 2),
+            gemlite_uintx_weight_only: (),
+            int4_weight_only: (),
+            int8_dynamic_activation_int4_weight: (),
+            int8_dynamic_activation_int8_weight: (),
+            int8_weight_only: (),
+            uintx_weight_only: (torch.uint4,),
+        }
+
+        with warnings.catch_warnings(record=True) as _warnings:
+            # Call each deprecated API twice
+            for cls, args in deprecated_apis_to_args.items():
+                cls(*args)
+                cls(*args)
+
+            # Each call should trigger the warning only once
+            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
+            for w in _warnings:
+                self.assertIn(
+                    "is deprecated and will be removed in a future release",
+                    str(w.message),
+                )
+
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
 
```
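A note on why the test expects exactly one warning per deprecated API even though each is called twice: after `warnings.resetwarnings()`, Python's default warning action shows a given warning only once per call site, so repeated identical warnings are deduplicated. A minimal, standard-library-only sketch of that behavior (not part of the diff):

```python
import warnings


def deprecated_call():
    # Same message, same line: a candidate for deduplication under the default action
    warnings.warn("x is deprecated and will be removed in a future release")


warnings.resetwarnings()  # clear filters so the default "once per location" action applies
with warnings.catch_warnings(record=True) as caught:
    deprecated_call()
    deprecated_call()

print(len(caught))  # 1 -- the repeated, identical warning is suppressed
```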
torchao/quantization/quant_api.py

Lines changed: 19 additions & 8 deletions

```diff
@@ -92,6 +92,7 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
+    _ConfigDeprecationWrapper,
     _is_fbgemm_genai_gpu_available,
     is_MI300,
     is_sm_at_least_89,
@@ -639,7 +640,9 @@ def __post_init__(self):
 
 
 # for BC
-int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
+int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
+    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
+)
 
 
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
@@ -1018,7 +1021,9 @@ def __post_init__(self):
 
 
 # for BC
-gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
+gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
+    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
+)
 
 
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
@@ -1100,7 +1105,7 @@ def __post_init__(self):
 
 # for BC
 # TODO maybe change other callsites
-int4_weight_only = Int4WeightOnlyConfig
+int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
 
 
 def _int4_weight_only_quantize_tensor(weight, config):
@@ -1310,7 +1315,7 @@ def __post_init__(self):
 
 
 # for BC
-int8_weight_only = Int8WeightOnlyConfig
+int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
 
 
 def _int8_weight_only_quantize_tensor(weight, config):
@@ -1471,7 +1476,9 @@ def __post_init__(self):
 
 
 # for BC
-int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
+int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
+    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
+)
 
 
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
@@ -1580,7 +1587,9 @@ def __post_init__(self):
 
 
 # for BC
-float8_weight_only = Float8WeightOnlyConfig
+float8_weight_only = _ConfigDeprecationWrapper(
+    "float8_weight_only", Float8WeightOnlyConfig
+)
 
 
 def _float8_weight_only_quant_tensor(weight, config):
@@ -1994,7 +2003,9 @@ def __post_init__(self):
 
 
 # for BC
-uintx_weight_only = UIntXWeightOnlyConfig
+uintx_weight_only = _ConfigDeprecationWrapper(
+    "uintx_weight_only", UIntXWeightOnlyConfig
+)
 
 
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
@@ -2234,7 +2245,7 @@ def __post_init__(self):
 
 
 # for BC
-fpx_weight_only = FPXWeightOnlyConfig
+fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
 
 
 @register_quantize_module_handler(FPXWeightOnlyConfig)
```
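One behavioral detail worth noting, sketched below under the same import assumptions as the test file: each BC name is now a `_ConfigDeprecationWrapper` instance rather than the config class itself, so code that compared the alias to the class by identity would need updating, while ordinary call sites keep receiving the same config objects.

```python
from torchao.quantization import Int4WeightOnlyConfig, int4_weight_only

config = int4_weight_only()  # emits the deprecation warning, then returns a real config

assert isinstance(config, Int4WeightOnlyConfig)       # calls still produce the config class
assert int4_weight_only is not Int4WeightOnlyConfig   # but the alias is no longer the class itself
```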

torchao/utils.py

Lines changed: 20 additions & 1 deletion

```diff
@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Type
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -432,6 +432,25 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
+class _ConfigDeprecationWrapper:
+    """
+    A deprecation wrapper that directs users from a deprecated "config function"
+    (e.g. `int4_weight_only`) to the replacement config class.
+    """
+
+    def __init__(self, deprecated_name: str, config_cls: Type):
+        self.deprecated_name = deprecated_name
+        self.config_cls = config_cls
+
+    def __call__(self, *args, **kwargs):
+        warnings.warn(
+            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
+            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
+            f"    quantize_(model, {self.config_cls.__name__}(...))"
+        )
+        return self.config_cls(*args, **kwargs)
+
+
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
```
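To see the wrapper above in isolation, here is a hypothetical sketch that wraps a placeholder dataclass. `MyConfig` and `my_config_fn` are illustrative names, not torchao APIs, and `_ConfigDeprecationWrapper` is private API imported here only for demonstration.

```python
import warnings
from dataclasses import dataclass

from torchao.utils import _ConfigDeprecationWrapper


@dataclass
class MyConfig:
    group_size: int = 128


# Emulate a legacy "config function" name that should now point users at MyConfig
my_config_fn = _ConfigDeprecationWrapper("my_config_fn", MyConfig)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    config = my_config_fn(group_size=64)

assert isinstance(config, MyConfig)  # the wrapper forwards args and returns a real config
assert "is deprecated and will be removed in a future release" in str(caught[0].message)
```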
