Skip to content

Commit 64dccf3

Browse files
committed
Deprecate config functions like int4_weight_only
**Summary:** These config functions have been superseded by `AOBaseConfig` objects for several releases already, but we never deprecated them. We will keep them around for one more release before breaking backward compatibility and removing them.

**Test Plan:**

```
python test/quantization/test_quant_api.py -k test_config_deprecation
```

ghstack-source-id: 9b2f9a3
Pull Request resolved: #2994
1 parent 10ba659 commit 64dccf3

File tree

3 files changed

+99
-12
lines changed

3 files changed

+99
-12
lines changed

test/quantization/test_quant_api.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import gc
1111
import tempfile
1212
import unittest
13+
import warnings
1314
from pathlib import Path
1415

1516
import torch
@@ -752,6 +753,56 @@ def test_int4wo_cuda_serialization(self):
752753
# load state_dict in cuda
753754
model.load_state_dict(sd, assign=True)
754755

756+
def test_config_deprecation(self):
    """
    Test that old config functions like `int4_weight_only` still construct the
    replacement config object but trigger a deprecation warning, and that
    repeated calls from the same call site warn only once.
    """
    from torchao.quantization import (
        float8_dynamic_activation_float8_weight,
        float8_static_activation_float8_weight,
        float8_weight_only,
        fpx_weight_only,
        gemlite_uintx_weight_only,
        int4_dynamic_activation_int4_weight,
        int4_weight_only,
        int8_dynamic_activation_int4_weight,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        uintx_weight_only,
    )

    # Reset deprecation warning state, otherwise we won't log warnings here
    warnings.resetwarnings()

    # Map from deprecated API to the positional args needed to instantiate it.
    # Every value must be a tuple, since it is splatted via `cls(*args)` below.
    deprecated_apis_to_args = {
        float8_dynamic_activation_float8_weight: (),
        # Trailing comma is required: `(torch.randn(3))` is just a tensor, and
        # `cls(*tensor)` would unpack it into three scalar arguments instead
        # of passing the scale tensor as a single argument.
        float8_static_activation_float8_weight: (torch.randn(3),),
        float8_weight_only: (),
        fpx_weight_only: (3, 2),
        gemlite_uintx_weight_only: (),
        int4_dynamic_activation_int4_weight: (),
        int4_weight_only: (),
        int8_dynamic_activation_int4_weight: (),
        int8_dynamic_activation_int8_weight: (),
        int8_weight_only: (),
        uintx_weight_only: (torch.uint4,),
    }

    with warnings.catch_warnings(record=True) as _warnings:
        # Call each deprecated API twice
        for cls, args in deprecated_apis_to_args.items():
            cls(*args)
            cls(*args)

        # Each call should trigger the warning only once
        self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
        for w in _warnings:
            self.assertIn(
                "is deprecated and will be removed in a future release",
                str(w.message),
            )
805+
755806

756807
common_utils.instantiate_parametrized_tests(TestQuantFlow)
757808

torchao/quantization/quant_api.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
to_weight_tensor_with_linear_activation_quantization_metadata,
9393
)
9494
from torchao.utils import (
95+
_ConfigDeprecationWrapper,
9596
_is_fbgemm_genai_gpu_available,
9697
is_MI300,
9798
is_sm_at_least_89,
@@ -639,7 +640,9 @@ def __post_init__(self):
639640

640641

641642
# for BC
642-
int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
643+
int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
644+
"int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
645+
)
643646

644647

645648
@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
@@ -957,7 +960,9 @@ def __post_init__(self):
957960

958961

959962
# for bc
960-
int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
963+
int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
964+
"int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
965+
)
961966

962967

963968
@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
@@ -1018,7 +1023,9 @@ def __post_init__(self):
10181023

10191024

10201025
# for BC
1021-
gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
1026+
gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
1027+
"gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
1028+
)
10221029

10231030

10241031
@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
@@ -1100,7 +1107,7 @@ def __post_init__(self):
11001107

11011108
# for BC
11021109
# TODO maybe change other callsites
1103-
int4_weight_only = Int4WeightOnlyConfig
1110+
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
11041111

11051112

11061113
def _int4_weight_only_quantize_tensor(weight, config):
@@ -1310,7 +1317,7 @@ def __post_init__(self):
13101317

13111318

13121319
# for BC
1313-
int8_weight_only = Int8WeightOnlyConfig
1320+
int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
13141321

13151322

13161323
def _int8_weight_only_quantize_tensor(weight, config):
@@ -1471,7 +1478,9 @@ def __post_init__(self):
14711478

14721479

14731480
# for BC
1474-
int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
1481+
int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
1482+
"int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
1483+
)
14751484

14761485

14771486
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
@@ -1580,7 +1589,9 @@ def __post_init__(self):
15801589

15811590

15821591
# for BC
1583-
float8_weight_only = Float8WeightOnlyConfig
1592+
float8_weight_only = _ConfigDeprecationWrapper(
1593+
"float8_weight_only", Float8WeightOnlyConfig
1594+
)
15841595

15851596

15861597
def _float8_weight_only_quant_tensor(weight, config):
@@ -1738,7 +1749,9 @@ def __post_init__(self):
17381749

17391750

17401751
# for bc
1741-
float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
1752+
float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
1753+
"float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
1754+
)
17421755

17431756

17441757
def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
@@ -1911,7 +1924,9 @@ def __post_init__(self):
19111924

19121925

19131926
# for bc
1914-
float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
1927+
float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
1928+
"float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
1929+
)
19151930

19161931

19171932
@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
@@ -1994,7 +2009,9 @@ def __post_init__(self):
19942009

19952010

19962011
# for BC
1997-
uintx_weight_only = UIntXWeightOnlyConfig
2012+
uintx_weight_only = _ConfigDeprecationWrapper(
2013+
"uintx_weight_only", UIntXWeightOnlyConfig
2014+
)
19982015

19992016

20002017
@register_quantize_module_handler(UIntXWeightOnlyConfig)
@@ -2234,7 +2251,7 @@ def __post_init__(self):
22342251

22352252

22362253
# for BC
2237-
fpx_weight_only = FPXWeightOnlyConfig
2254+
fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
22382255

22392256

22402257
@register_quantize_module_handler(FPXWeightOnlyConfig)

torchao/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import reduce
1313
from importlib.metadata import version
1414
from math import gcd
15-
from typing import Any, Callable, Optional
15+
from typing import Any, Callable, Optional, Type
1616

1717
import torch
1818
import torch.nn.utils.parametrize as parametrize
@@ -432,6 +432,25 @@ def __eq__(self, other):
432432
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
433433

434434

435+
class _ConfigDeprecationWrapper:
436+
"""
437+
A deprecation wrapper that directs users from a deprecated "config function"
438+
(e.g. `int4_weight_only`) to the replacement config class.
439+
"""
440+
441+
def __init__(self, deprecated_name: str, config_cls: Type):
442+
self.deprecated_name = deprecated_name
443+
self.config_cls = config_cls
444+
445+
def __call__(self, *args, **kwargs):
446+
warnings.warn(
447+
f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
448+
f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
449+
f" quantize_(model, {self.config_cls.__name__}(...))"
450+
)
451+
return self.config_cls(*args, **kwargs)
452+
453+
435454
"""
436455
Helper function for implementing aten op or torch function dispatch
437456
and dispatching to these implementations.

0 commit comments

Comments
 (0)