Commit 54bcd5a

Making subclass BC with torch version
Summary: there was a change to __tensor_flatten__ and __tensor_unflatten__ in PyTorch core; this change allows the subclass to work with both the 2.2 branch cut and main.

Test Plan: python test/test.py on branch cut 2.2 and main

ghstack-source-id: afa090f
Pull Request resolved: #27
1 parent 9aaf3ec commit 54bcd5a

File tree

4 files changed: +21 −16 lines


torchao/quantization/dynamic_quant.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def __init__(
     ) -> None:
         super().__init__(in_features, out_features, bias)

-    def forward(self, X: torch.Tensor) -> torch.Tensor:
+    def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Performs the forward pass of the quantized linear layer which consists
         of int8 dynamic symmetric per-token activation and int8 symmetric per-channel weight
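
This, together with the matching change to weight_only.py below, relaxes the forward signature so the swapped-in module tolerates call sites that pass extra positional or keyword arguments. A minimal sketch of the pattern with a toy class (not the torchao module):

import torch

class ToyQuantLinear(torch.nn.Linear):
    # Mirrors the diff: accept and ignore extra arguments so callers that
    # forward additional args do not break after the module swap.
    def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return super().forward(X)

lin = ToyQuantLinear(16, 8)
x = torch.randn(2, 16)
print(lin(x).shape)                        # torch.Size([2, 8])
print(lin(x, None, unused_kwarg=1).shape)  # extra arguments are now tolerated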

torchao/quantization/quant_api.py

Lines changed: 11 additions & 7 deletions
@@ -54,7 +54,8 @@ def _replace_with_custom_fn_if_matches_filter(
         new_child = _replace_with_custom_fn_if_matches_filter(
             child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
         )
-        setattr(model, name, new_child)
+        if new_child is not child:
+            setattr(model, name, new_child)
     return model


@@ -68,15 +69,15 @@ def _is_linear(mod, *args):
 def _in_features_greater_than_16(mod, *args):
     return hasattr(mod, "in_features") and mod.in_features > 16

-def apply_weight_only_int8_quant(model):
+def apply_weight_only_int8_quant(model, filter_fn=None):
     """
     Applies weight-only symmetric per-channel int8 quantization to all linear layers
     in the given model using module swaps.
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
         WeightOnlyInt8QuantLinear.from_float,
-        _is_linear,
+        _is_linear if filter_fn is None else filter_fn,
     )


@@ -123,7 +124,7 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
     )


-def change_linear_weights_to_int8_woqtensors(model):
+def change_linear_weights_to_int8_woqtensors(model, filter_fn=None):
     """
     Converts all linear weight tensors to the
     `Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
@@ -133,7 +134,7 @@ def change_linear_weights_to_int8_woqtensors(model):
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight),
-        _is_linear,
+        _is_linear if filter_fn is None else filter_fn,
     )


@@ -152,7 +153,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
         filter_fn,
     )

-def swap_conv2d_1x1_to_linear(model):
+def swap_conv2d_1x1_to_linear(model, filter_fn=None):
     """
     Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.
     """
@@ -172,8 +173,11 @@ def replace_conv2d_1x1(conv):
         lin.bias = conv.bias
         return PermuteSandwich(lin)

+    if filter_fn is None:
+        filter_fn=lambda mod, *args: isinstance(mod, torch.nn.Conv2d) and mod.kernel_size==(1,1)
+
     _replace_with_custom_fn_if_matches_filter(
         model,
         replace_conv2d_1x1,
-        filter_fn=lambda mod, *args: isinstance(mod, torch.nn.Conv2d) and mod.kernel_size==(1,1)
+        filter_fn=filter_fn
     )
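
The new filter_fn parameter lets callers override which modules get swapped; passing nothing preserves the old _is_linear behavior. A usage sketch (the model is hypothetical; the import path is inferred from the file path above):

import torch
from torchao.quantization.quant_api import apply_weight_only_int8_quant

model = torch.nn.Sequential(
    torch.nn.Linear(8, 32),   # in_features <= 16: skipped by the custom filter
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),  # in_features > 16: swapped for the quantized module
)

# Quantize only the larger linear layers, mirroring _in_features_greater_than_16.
apply_weight_only_int8_quant(
    model,
    filter_fn=lambda mod, *args: isinstance(mod, torch.nn.Linear) and mod.in_features > 16,
)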

torchao/quantization/subclass.py

Lines changed: 8 additions & 7 deletions
@@ -240,13 +240,13 @@ def _change_shape(self, shape):
         )

     def __tensor_flatten__(self):
-        return ["int_data", "q_scales"], [self.transposed, self.dtype]
+        return ["int_data", "q_scales"], [self.transposed, self.dtype, self.shape]

     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
+    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
         int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"]
-        transposed, dtype = tensor_attributes
-        return cls(int_data, q_scales, transposed, outer_size, dtype=dtype, strides=outer_stride)
+        transposed, dtype, shape = tensor_attributes
+        return cls(int_data, q_scales, transposed, shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

     @classmethod
     def from_float(cls, input_float, qmin=-128, qmax=127):
@@ -416,20 +416,21 @@ def __tensor_flatten__(self):
             self.groupsize,
             self.inner_k_tiles,
             self.dtype,
+            self.shape
         )

     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, attributes, outer_size, outer_stride):
+    def __tensor_unflatten__(cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None):
         int_data, scales_and_zeros = (
             tensor_data_dict["int_data"],
             tensor_data_dict["scales_and_zeros"],
         )
-        transposed, groupsize, inner_k_tiles, dtype = attributes
+        transposed, groupsize, inner_k_tiles, dtype, shape = attributes
         return cls(
             int_data,
             scales_and_zeros,
             transposed,
-            outer_size,
+            shape if outer_size is None else outer_size,
             groupsize,
             inner_k_tiles,
             dtype=dtype,
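
Both hunks apply the same fix: record self.shape in the flatten metadata and default outer_size/outer_stride to None, so __tensor_unflatten__ works under both the older two-argument and the newer four-argument calling convention in PyTorch core. A self-contained sketch of the idea with a toy subclass (hypothetical class, built on the private _make_wrapper_subclass helper; not the torchao classes):

import torch

class ToyQuantTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, int_data, scale, shape):
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=torch.float32)

    def __init__(self, int_data, scale, shape):
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self):
        # Stash the outer shape in the metadata so unflatten can recover it
        # even when the caller does not pass outer_size.
        return ["int_data", "scale"], [self.shape]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
        (shape,) = tensor_attributes
        return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"],
                   shape if outer_size is None else outer_size)

t = ToyQuantTensor(torch.zeros(4, 4, dtype=torch.int8), torch.tensor(0.1), torch.Size([4, 4]))
names, meta = t.__tensor_flatten__()
inner = {name: getattr(t, name) for name in names}
# Both calling conventions now round-trip:
old_style = ToyQuantTensor.__tensor_unflatten__(inner, meta)                        # pre-change core
new_style = ToyQuantTensor.__tensor_unflatten__(inner, meta, t.shape, t.stride())  # newer core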

torchao/quantization/weight_only.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs):
         self.w_int8 = w_int8
         self.scales = scales

-    def forward(self, x):
+    def forward(self, x, *args, **kwargs):
         """
         Performs the forward pass of the quantized linear layer which consists
         of mixed dtype matmul using int8 symmetric per-channel weight quantization
