
Commit b56f51d

Subclass fixes for torchbench
Summary: A PyTorch update changed flatten behavior; the subclasses are now updated to match. This commit also handles a number of behaviors needed to work with the torchbench dynamo userbenchmark: the transpose and detach code was removed and refactored into _change_shape and _apply_fn_to_data methods, and the subclasses now override __torch_function__ for torch.nn.functional.linear rather than mm and addmm individually, since linear can also hit expand, view, bmm, etc.

Test Plan: python test/test.py

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: a5c0146
Pull Request resolved: #24
1 parent 1b92d57 commit b56f51d
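The commit message describes moving interception from mm/addmm to torch.nn.functional.linear via __torch_function__, and funneling data-only ops through an _apply_fn_to_data helper. Below is a minimal, hypothetical sketch of that pattern; FakeQuantizedLinearWeight and its fields are illustrative, not the repo's actual subclass.

import torch
import torch.nn.functional as F


class FakeQuantizedLinearWeight(torch.Tensor):
    """Illustrative wrapper subclass; not the torchao implementation."""

    @staticmethod
    def __new__(cls, int_data, scale):
        # The outer tensor mimics the shape/dtype of the original float weight.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale

    def dequantize(self):
        return self.int_data.to(self.scale.dtype) * self.scale

    def _apply_fn_to_data(self, fn):
        # One helper for data-only ops (detach, transpose, ...) instead of
        # separate per-op code paths.
        return type(self)(fn(self.int_data), fn(self.scale))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
            # Intercepting F.linear in one place covers the mm/addmm/bmm/
            # expand/view paths it may otherwise decompose into.
            return F.linear(x, w.dequantize(), bias)
        if func is torch.Tensor.detach:
            return args[0]._apply_fn_to_data(torch.Tensor.detach)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


w = FakeQuantizedLinearWeight(
    torch.randint(-8, 8, (4, 8), dtype=torch.int8), torch.full((4, 1), 0.05)
)
print(F.linear(torch.randn(2, 8), w).shape)  # torch.Size([2, 4])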

4 files changed: +112, -69 lines


test/test.py

Lines changed: 12 additions & 10 deletions

@@ -832,11 +832,12 @@ def test_dequantize_int4_weight_only_quant_subclass(self):
         for groupsize in [256, 128]:
             for inner_k_tiles in [8, 2]:
                 for m in [1, 256]:
-                    self._test_dequantize_impl(
-                        lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles),
-                        15,
-                        test_shape=[m, 256, 8]
-                    )
+                    for n in [8, 13]:
+                        self._test_dequantize_impl(
+                            lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles),
+                            15,
+                            test_shape=[m, 256, n]
+                        )
 
     def _test_lin_weight_subclass_impl(
         self,
@@ -886,11 +887,12 @@ def test_int4_weight_only_quant_subclass(self):
         for groupsize in [128, 64]:
             for inner_k_tiles in [4, 2]:
                 for m in [1, 256]:
-                    self._test_lin_weight_subclass_impl(
-                        lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles),
-                        10,
-                        test_shape=[m, 256, 8]
-                    )
+                    for n in [8, 13]:
+                        self._test_lin_weight_subclass_impl(
+                            lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles),
+                            10,
+                            test_shape=[m, 256, n]
+                        )
 
     @torch.no_grad()
     def _test_lin_weight_subclass_api_impl(
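The new inner loop sweeps the output dimension n over 8 and a non-multiple-of-8 value (13), so the int4 kernels are exercised with shapes like [1, 256, 13]. A rough, hypothetical sketch of what a test_shape of [m, k, n] corresponds to (the helper name and SQNR check below are illustrative, not the repo's exact test code):

import torch
import torch.nn.functional as F

def run_shape_case(m, k, n, quantize_weight, min_sqnr):
    # test_shape = [m, k, n]: activations are (m, k), the linear weight is (n, k).
    x = torch.randn(m, k)
    w = torch.randn(n, k)
    ref = F.linear(x, w)
    out = F.linear(x, quantize_weight(w))
    # SQNR-style closeness check, analogous to the 15 / 10 dB thresholds in the diff.
    sqnr = 10 * torch.log10(ref.pow(2).mean() / (ref - out).pow(2).mean())
    assert sqnr >= min_sqnr, f"SQNR {sqnr:.1f} dB below {min_sqnr} dB"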

torchao/quantization/quant_api.py

Lines changed: 17 additions & 5 deletions

@@ -20,6 +20,7 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from .subclass import (
+    QuantizedLinearWeightBase,
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
@@ -58,6 +59,15 @@ def _replace_with_custom_fn_if_matches_filter(
             child, replacement_fn, filter_fn, new_fqn
         )
 
+def _is_linear(mod, *args):
+    return (
+        isinstance(mod, torch.nn.Linear) and
+        hasattr(mod, "weight") and
+        not isinstance(mod.weight, QuantizedLinearWeightBase)
+    )
+
+def _in_features_greater_than_16(mod, *args):
+    return hasattr(mod, "in_features") and mod.in_features > 16
 
 def apply_weight_only_int8_quant(model):
     """
@@ -67,7 +77,7 @@ def apply_weight_only_int8_quant(model):
     _replace_with_custom_fn_if_matches_filter(
         model,
         WeightOnlyInt8QuantLinear.from_float,
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
+        _is_linear,
     )
 
 
@@ -80,7 +90,7 @@ def apply_dynamic_quant(model):
     _replace_with_custom_fn_if_matches_filter(
         model,
         lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
+        _is_linear,
     )
 
 
@@ -103,7 +113,9 @@ def change_linear_weights_to_int8_dqtensors(model):
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight),
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
+        lambda *args:
+            _is_linear(*args) and
+            _in_features_greater_than_16(*args)
    )
 
 
@@ -117,7 +129,7 @@ def change_linear_weights_to_int8_woqtensors(model):
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight),
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
+        _is_linear,
    )
 
 
@@ -131,5 +143,5 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
+        _is_linear,
    )
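A small usage sketch of the new filters (the model layout here is made up): an nn.Linear passes _is_linear only while its weight is still a plain float tensor, and the int8 dynamic-quant path additionally requires in_features > 16, so tiny linears are left untouched and re-running the transform is a no-op.

import torch
from torchao.quantization.quant_api import (
    _is_linear,
    _in_features_greater_than_16,
    change_linear_weights_to_int8_dqtensors,
)

model = torch.nn.Sequential(
    torch.nn.Linear(512, 256),  # passes both filters
    torch.nn.Linear(8, 4),      # fails the in_features > 16 filter
)
print(_is_linear(model[0]), _in_features_greater_than_16(model[0]))  # True True
print(_is_linear(model[1]), _in_features_greater_than_16(model[1]))  # True False

# Only model[0] gets its weight swapped for Int8DynamicallyQuantizedLinearWeight;
# a second call is a no-op because _is_linear rejects weights that are already
# QuantizedLinearWeightBase instances.
change_linear_weights_to_int8_dqtensors(model)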

torchao/quantization/quant_primitives.py

Lines changed: 2 additions & 5 deletions

@@ -351,9 +351,6 @@ def quant_int8_per_token_matmul(
     assert (
         w_vals_int8_t.dtype == torch.int8
     ), f"w dtype {w_vals_int8_t.dtype} not yet supported"
-    assert (
-        w_scales.dtype == output_dtype
-    ), f"{w_scales.dtype} does not match {output_dtype}"
 
     #
     # 1. do the matrix form of dot(X_i, W_j)
@@ -375,8 +372,8 @@ def quant_int8_per_token_matmul(
         torch.bfloat16,
     ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
 
-    y = (y_dot_int32 * x_scales.view(-1, 1) * w_scales).reshape(
-        *x_vals_int8.shape[:-1], -1
+    y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape(
+        *x_vals_int8.shape[:-1], y_dot_int32.shape[-1]
     )
 
     # can downcast only at the very end