Skip to content

Commit 9aaf3ec

Browse files
committed
fixes for sdxl
Summary: added filtering to the API, and added an API to convert 1x1 convolutions to linear layers so they can be quantized. Also fixed the tensor subclass to avoid situations where the weight-only quantized weight isn't contiguous. Test Plan: python test/test.py Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: e22d744 Pull Request resolved: #26
1 parent 793fa01 commit 9aaf3ec

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

torchao/quantization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"change_linear_weights_to_int8_dqtensors",
1919
"change_linear_weights_to_int8_woqtensors",
2020
"change_linear_weights_to_int4_woqtensors",
21-
"insert_subclass",
21+
"swap_conv2d_1x1_to_linear"
2222
"safe_int_mm",
2323
"dynamically_quantize_per_tensor",
2424
"quantize_activation_per_token_absmax",

torchao/quantization/quant_api.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"change_linear_weights_to_int8_dqtensors",
3636
"change_linear_weights_to_int8_woqtensors",
3737
"change_linear_weights_to_int4_woqtensors",
38+
"swap_conv2d_1x1_to_linear"
3839
]
3940

4041

@@ -45,19 +46,17 @@ def _replace_with_custom_fn_if_matches_filter(
4546
For each `child` in `model`, replaces it with `replacement_fn(child)`
4647
if `filter_fn(child)` is `True`
4748
"""
48-
name_to_child = dict(model.named_children())
49-
for name, child in name_to_child.items():
50-
if cur_fqn == "":
51-
new_fqn = name
52-
else:
53-
new_fqn = f"{cur_fqn}.{name}"
54-
if filter_fn(child, new_fqn):
55-
new_child = replacement_fn(child)
56-
setattr(model, name, new_child)
57-
else:
58-
_replace_with_custom_fn_if_matches_filter(
59-
child, replacement_fn, filter_fn, new_fqn
49+
if filter_fn(model, cur_fqn[:-1]):
50+
model = replacement_fn(model)
51+
return model
52+
else:
53+
for name, child in model.named_children():
54+
new_child = _replace_with_custom_fn_if_matches_filter(
55+
child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
6056
)
57+
setattr(model, name, new_child)
58+
return model
59+
6160

6261
def _is_linear(mod, *args):
6362
return (
@@ -81,7 +80,7 @@ def apply_weight_only_int8_quant(model):
8180
)
8281

8382

84-
def apply_dynamic_quant(model):
83+
def apply_dynamic_quant(model, filter_fn=None):
8584
"""
8685
Applies dynamic symmetric per-token activation and per-channel weight
8786
quantization to all linear layers in the given model using
@@ -90,7 +89,7 @@ def apply_dynamic_quant(model):
9089
_replace_with_custom_fn_if_matches_filter(
9190
model,
9291
lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
93-
_is_linear,
92+
_is_linear if filter_fn is None else filter_fn,
9493
)
9594

9695

@@ -104,18 +103,23 @@ def insert_subclass(lin):
104103
return insert_subclass
105104

106105

107-
def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
    """
    Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
    Tensor subclass, effectively applying the same form of quantization
    as apply_dynamic_quant while not modifying the linear modules.

    Args:
        model: the model whose linear weights will be swapped in place.
        filter_fn: optional `(module, fqn) -> bool` predicate selecting which
            modules to convert. Defaults to linear layers with more than 16
            in_features (int8 dynamic quant kernels need sufficiently large
            inner dims to be profitable — TODO confirm exact rationale).
    """
    def _default_filter(*args):
        # Fix: was a multi-line lambda assigned to a name (PEP 8 E731);
        # a named helper is clearer and easier to debug.
        return _is_linear(*args) and _in_features_greater_than_16(*args)

    if filter_fn is None:
        filter_fn = _default_filter

    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight),
        filter_fn,
    )
120124

121125

@@ -140,8 +144,36 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
140144
effectively applying the same form of quantization
141145
as apply_dynamic_quant while not modifying the linear modules.
142146
"""
147+
filter_fn = kwargs.pop("filter_fn", _is_linear)
148+
143149
_replace_with_custom_fn_if_matches_filter(
144150
model,
145151
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
146-
_is_linear,
152+
filter_fn,
153+
)
154+
155+
def swap_conv2d_1x1_to_linear(model):
    """
    Changes all conv2d 1x1 modules to equivalent linear modules so that
    they can then be quantized.

    Args:
        model: the model whose 1x1 Conv2d modules are replaced in place.
    """
    class PermuteSandwich(torch.nn.Module):
        # Adapts a Linear to consume NCHW conv activations: permute to
        # channels-last (N, H, W, C), apply the linear over C, permute back.
        def __init__(self, mod):
            super().__init__()
            self.mod = mod

        def forward(self, *args):
            # Fix: original used permute(-0, 3, 1, 2); -0 == 0 but is confusing.
            return self.mod(args[0].permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def replace_conv2d_1x1(conv):
        assert conv.kernel_size == (1, 1)
        # Fix: original passed bias=(conv.bias is None), which is inverted —
        # it allocated a bias parameter exactly when the conv had none. The
        # linear should have a bias iff the conv does.
        lin = torch.nn.Linear(
            conv.in_channels, conv.out_channels, bias=(conv.bias is not None)
        )
        # 1x1 conv weight has shape (out, in, 1, 1); drop the trailing
        # spatial dims to get the (out, in) linear weight.
        lin.weight = torch.nn.Parameter(conv.weight.squeeze(-1, -2))
        lin.bias = conv.bias
        return PermuteSandwich(lin)

    _replace_with_custom_fn_if_matches_filter(
        model,
        replace_conv2d_1x1,
        filter_fn=lambda mod, *args: isinstance(mod, torch.nn.Conv2d)
        and mod.kernel_size == (1, 1),
    )

torchao/quantization/subclass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
269269
# however the external representation of our tensor will maintain the correct
270270
# shape attribute which needs to be tracked directly.
271271
int_data = w_int_repr.contiguous().t()
272+
if cls is not Int8DynamicallyQuantizedLinearWeight:
273+
int_data = int_data.contiguous()
272274
return cls(
273275
int_data, w_scales, False, input_float.shape, dtype=input_float.dtype
274276
)

0 commit comments

Comments
 (0)