Commit 49694e3

migrate prototype/awq to configs (#1853)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent a91275a commit 49694e3
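
At the call-site level, this migration replaces the function-style `awq_uintx(...)` factory with an `AOBaseConfig` dataclass, `AWQUIntXConfig`, that `quantize_()` dispatches on (the old name is kept as an alias, as shown in the diff below). A minimal usage sketch, assuming the model has already gone through the unchanged AWQ observer/calibration flow; the toy model here matches nothing and is only illustrative:

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq.api import AWQUIntXConfig
from torchao.prototype.awq.core import AWQObservedLinear

# Toy model; in a real flow, AWQ observers are inserted and calibrated first
# (that part of the API is unchanged by this commit).
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16).eval()

# Only quantize layers that were wrapped as AWQObservedLinear during calibration.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)

# Before: quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), is_observed_linear)
# After:  pass the config object; quantize_() routes it to the registered transform.
quantize_(
    model,
    AWQUIntXConfig(quant_dtype=torch.uint4, group_size=64, use_hqq=False),
    is_observed_linear,
)
```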

File tree

1 file changed: +79 -71 lines changed

torchao/prototype/awq/api.py

Lines changed: 79 additions & 71 deletions
@@ -1,18 +1,28 @@
+import types
+from dataclasses import dataclass
+
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
 )
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 from .core import (
     AWQObservedLinear,
@@ -82,88 +92,86 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
 
-def _observed_linear_subclass_inserter(constructor):
+@dataclass
+class AWQUIntXConfig(AOBaseConfig):
     """
-    Replaces unquantized AWQObservedLinear instances with quantized linear instances.
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
-        constructor: the function which applies quantization to the AWQObservedLinear layer
+        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
+        group_size: Quantization granularity. Use -1 for channel wise quantization
+        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
     """
 
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias != None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
+    quant_dtype: torch.dtype = torch.uint4
+    group_size: int = 64
+    use_hqq: bool = False
 
-    return insert_subclass
 
+# for bc
+awq_uintx = AWQUIntXConfig
 
-def awq_uintx(
-    quant_dtype: torch.dtype = torch.uint4,
-    group_size: int = 64,
-    use_hqq: bool = False,
-):
-    """
-    Quantizes linear layers when passed into quantize_()
 
-    Args:
-        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
-        group_size: Quantization granularity. Use -1 for channel wise quantization
-        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
-    """
+@register_quantize_module_handler(AWQUIntXConfig)
+def _awq_uintx_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    quant_dtype = config.quant_dtype
+    group_size = config.group_size
+    use_hqq = config.use_hqq
+    observed_linear = module
+
     assert (
         quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
     ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
 
-    def weight_quant_func(observed_linear):
-        equalization_scale = observed_linear.act_obs.calculate_qparams()
-        # AQT config
-        if quant_dtype == torch.uint4:
-            target_dtype = torch.int32
-            eps = 1e-6
-            preserve_zero = False
-            zero_point_dtype = torch.bfloat16
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
-        else:
-            target_dtype = torch.uint8
-            eps = torch.finfo(torch.float32).eps
-            preserve_zero = True
-            zero_point_dtype = torch.int64
-            zero_point_domain = ZeroPointDomain.INT
-            _layout = UintxLayout(quant_dtype)
-
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
-        quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
-        qw = to_affine_quantized_intx(
-            observed_linear.weight * equalization_scale,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+    # AQT config
+    if quant_dtype == torch.uint4:
+        target_dtype = torch.int32
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+    else:
+        target_dtype = torch.uint8
+        eps = torch.finfo(torch.float32).eps
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        _layout = UintxLayout(quant_dtype)
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, equalization_scale
-        )
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
+    quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
+    qw = to_affine_quantized_intx(
+        observed_linear.weight * equalization_scale,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
 
-    return _observed_linear_subclass_inserter(weight_quant_func)
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias != None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
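
One consequence of the `# for bc` alias above: calling `awq_uintx(...)` no longer returns a weight-quantization closure; it constructs the config dataclass, which `quantize_()` then maps to `_awq_uintx_transform` through the `register_quantize_module_handler` registry. A small sketch of that behavior (the assertions are illustrative, not from the diff):

```python
import torch
from torchao.prototype.awq.api import AWQUIntXConfig, awq_uintx

# The legacy call site still works, but it now builds a config object
# rather than a callable to be applied to each observed linear layer.
cfg = awq_uintx(quant_dtype=torch.uint4, group_size=128)
assert isinstance(cfg, AWQUIntXConfig)
assert cfg.group_size == 128 and cfg.use_hqq is False
```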
