Commit e670872

smoothquant fixes
Summary: Certain custom linear modules add additional inputs to the forward that need to be handled, but can otherwise be ignored. Additionally, swap_linear_with_smooth_fq_linear had a bug where linear subclasses would get past the if statement and then error on the dict key lookup, since the actual class was not one of the expected keys. This change (#30) enables NonDynamicallyQuantizableLinear to work with smoothquant and fixes the bug for other subclasses. At some point this should be brought in line with the other APIs if it's getting use.

Test Plan: python test/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 059b392
Pull Request resolved: #28
1 parent 54bcd5a commit e670872
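
To make the subclass bug from the Summary concrete, here is a minimal standalone sketch. It is not part of the commit; the table below is a simplified stand-in for source_cls_to_target_cls in torchao/quantization/smoothquant.py. It shows why the old isinstance check admitted Linear subclasses that then failed the exact-type dict lookup, and why the new type(child) membership test avoids that:

import torch
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

# Simplified stand-in for the swap table in smoothquant.py; the real table
# maps source classes to SmoothFakeDynamicallyQuantizedLinear.
source_cls_to_target_cls = {torch.nn.Linear: "target"}

child = NonDynamicallyQuantizableLinear(8, 8)

# Old check: isinstance() is True for any torch.nn.Linear subclass ...
assert isinstance(child, tuple(source_cls_to_target_cls.keys()))
# ... but the swap then looks up the exact type, which is not a key,
# so source_cls_to_target_cls[type(child)] raised KeyError.
assert type(child) not in source_cls_to_target_cls

# New check: only swap when the exact type is registered; unregistered
# subclasses fall through to the recursive branch instead of erroring.
should_swap = type(child) in source_cls_to_target_cls.keys()
print(should_swap)  # False unless the subclass is added to the table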

File tree

2 files changed: +21 -3 lines changed

test/test.py

Lines changed: 17 additions & 0 deletions
@@ -1124,6 +1124,23 @@ def test_shape_logger(self):
 
 
 class SmoothquantIntegrationTest(unittest.TestCase):
+    @torch.no_grad()
+    def test_non_dynamically_quantizable_linear(self):
+        model = torch.nn.Sequential(
+            torch.nn.modules.linear.NonDynamicallyQuantizableLinear(32,32),
+            torch.nn.ReLU()
+        ).to("cuda").to(torch.bfloat16)
+        example_input = torch.randn(32,32, device="cuda", dtype=torch.bfloat16)
+        ref = model(example_input)
+        swap_linear_with_smooth_fq_linear(model)
+        model(ref)
+        smooth_fq_linear_to_inference(model)
+        model_c = torch.compile(model, mode="max-autotune")
+        out = model_c(example_input)
+        sqnr = SQNR(ref, out)
+        self.assertTrue(sqnr >= 25)
+        self.assertTrue(isinstance(model[0], SmoothFakeDynamicallyQuantizedLinear))
+
     @torch.inference_mode()
     def test_on_dummy_distilbert(self):
         # https://huggingface.co/distilbert-base-uncased#how-to-use
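
The assertion sqnr >= 25 in the new test is a decibel threshold comparing the compiled smooth-fq output against the bfloat16 reference. The SQNR helper itself is defined elsewhere in test/test.py and is not part of this diff; a typical definition (an assumption for illustration, not copied from the repo) looks like:

import torch

def sqnr(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; larger values mean the
    # quantized output tracks the reference more closely.
    noise = ref - out
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(noise))

ref = torch.randn(32, 32)
out = ref + 0.01 * torch.randn(32, 32)  # small perturbation standing in for quantization error
print(sqnr(ref, out))  # roughly 40 dB for this noise level, comfortably above 25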

torchao/quantization/smoothquant.py

Lines changed: 4 additions & 3 deletions
@@ -137,7 +137,7 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.init_smoothquant_variables(alpha)
 
-    def forward(self, X):
+    def forward(self, X, *args, **kwargs):
         if self.calibrating:
             self.update_x_running_abs_max(X)
             Y = F.linear(X, self.weight, self.bias)
@@ -199,6 +199,7 @@ def set_debug_x_absmax(self):
 
 source_cls_to_target_cls = {
     torch.nn.Linear: SmoothFakeDynamicallyQuantizedLinear,
+    torch.nn.modules.linear.NonDynamicallyQuantizableLinear: SmoothFakeDynamicallyQuantizedLinear,
 }
 
 
@@ -212,8 +213,8 @@ def swap_linear_with_smooth_fq_linear(
             new_fqn = name
         else:
             new_fqn = f"{cur_fqn}.{name}"
-        if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and isinstance(
-            child, tuple(source_cls_to_target_cls.keys())
+        if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (
+            type(child) in source_cls_to_target_cls.keys()
         ):
             target_cls = source_cls_to_target_cls[type(child)]
             new_child = target_cls.from_float(child, alpha=alpha)
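
The forward signature change above is what lets "additional inputs" from custom call sites be accepted and ignored. A minimal sketch of the pattern, using a hypothetical ExtraArgLinear that is not in the repo, standing in for callers that pass extra arguments to a linear's forward:

import torch

class ExtraArgLinear(torch.nn.Linear):
    # Hypothetical custom linear whose callers pass an extra flag the layer
    # itself ignores; stands in for the "additional inputs" case in the Summary.
    def forward(self, x, some_flag=False):
        return super().forward(x)

class SwappedLinear(torch.nn.Linear):
    # Sketch of the fixed forward signature: accept and ignore extra inputs so
    # existing call sites keep working after the module is swapped in.
    def forward(self, x, *args, **kwargs):
        return super().forward(x)

x = torch.randn(2, 4)
ExtraArgLinear(4, 4)(x, True)   # caller passes an extra positional input
SwappedLinear(4, 4)(x, True)    # still accepted after the swap, and ignored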
