Commit f0812a1

Authored by ved1beta, pre-commit-ci[bot], changwangss, and xin3he
fix(pytorch): Rename layer_scale parameter to avoid quantization error (#2172)
* fix(pytorch): Rename layer_scale parameter to avoid quantization error

---------

Signed-off-by: Sun, Xuehao <xuehao.sun@intel.com>
Signed-off-by: changwangss <chang1.wang@intel.com>
Signed-off-by: xin3he <xin3.he@intel.com>
Signed-off-by: V-E-D <vedantthote2019@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wang, Chang <chang1.wang@intel.com>
Co-authored-by: Xin He <xin3.he@intel.com>
1 parent cfb38fa commit f0812a1

2 files changed: +126 −2 lines changed

neural_compressor/adaptor/pytorch.py

Lines changed: 6 additions & 2 deletions
@@ -4170,8 +4170,12 @@ def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=""):
             sub_name = node.target
             if not hasattr(model, node.target):
                 continue
-            if "scale" in node.target:
-                tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
+            # Improved scale detection logic
+            if "scale" in node.target and not any(exclude in node.target for exclude in ["layer_scale", "gamma"]):
+                try:
+                    tune_cfg["get_attr"][sub_name] = getattr(model, node.target).tolist()
+                except Exception as e:
+                    logger.warning(f"Could not convert {node.target} to list, skipping... Error: {str(e)}")
             elif "zero_point" in node.target:
                 tune_cfg["get_attr"][sub_name] = int(getattr(model, node.target))
             else:
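
For context, a minimal standalone sketch (illustrative only, not part of this commit) of the failure the exclusion list and .tolist() conversion address: float() cannot convert a multi-element tensor such as a per-channel layer_scale parameter, while tolist() can.

import torch

# Illustrative only: a per-channel layer_scale parameter shaped like the one in the test model below.
layer_scale = torch.ones(64).unsqueeze(-1).unsqueeze(-1)  # shape (64, 1, 1)

# Old behavior: float() only accepts one-element tensors, so a multi-element
# layer_scale attribute picked up by the "scale" substring match raises here.
try:
    float(layer_scale)
except (RuntimeError, ValueError) as err:
    print(f"float() rejected the multi-element tensor: {err}")

# New behavior: tolist() works for tensors of any shape, and attribute names
# containing "layer_scale" or "gamma" are skipped by the scale detection entirely.
print(len(layer_scale.tolist()))  # 64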
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import unittest

import torch
import torch.nn as nn

from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig

torch.manual_seed(42)


class CalibDataloader:
    """Simple calibration dataloader for testing."""

    def __init__(self, data, label):
        self.data = data
        self.label = label
        self.batch_size = 1  # Since we're yielding single samples

    def __iter__(self):
        yield self.data, self.label


class ConvEncoderWithLayerScale(nn.Module):
    """Test model with layer_scale parameter that caused the original issue."""

    def __init__(self, dim=64, hidden_dim=128, kernel_size=3, drop_path=0.0, use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = nn.Identity() if drop_path <= 0.0 else nn.Dropout(drop_path)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = self.layer_scale * x
        x = input + self.drop_path(x)
        return x


class ConvEncoderWithLayerGamma(nn.Module):
    """Test model with renamed layer_gamma parameter (the fix)."""

    def __init__(self, dim=64, hidden_dim=128, kernel_size=3, drop_path=0.0, use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = nn.Identity() if drop_path <= 0.0 else nn.Dropout(drop_path)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_gamma = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = self.layer_gamma * x
        x = input + self.drop_path(x)
        return x


class TestPyTorchLayerScale(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        self.constant_data = torch.randn(1, 64, 32, 32)
        self.constant_label = torch.randint(0, 10, (1,))

    def test_layer_scale_error(self):
        """Test that a model with the original layer_scale parameter name now quantizes without error."""
        model = ConvEncoderWithLayerScale()
        model.eval()

        calib_dataloader = CalibDataloader(self.constant_data, self.constant_label)

        # Configure quantization
        conf = PostTrainingQuantConfig()

        # Quantize; with the improved scale detection this no longer fails on layer_scale
        q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
        # The quantization should succeed and return a model rather than None
        self.assertIsNotNone(q_model, "Quantization should succeed with layer_scale parameter")

    def test_layer_gamma_success(self):
        """Test that the renamed layer_gamma parameter works correctly."""
        model = ConvEncoderWithLayerGamma()
        model.eval()

        calib_dataloader = CalibDataloader(self.constant_data, self.constant_label)

        # Configure quantization
        conf = PostTrainingQuantConfig()

        # This should succeed with layer_gamma parameter
        try:
            q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
            self.assertIsNotNone(q_model)
        except ValueError as e:
            self.fail(f"Quantization failed with layer_gamma: {str(e)}")


if __name__ == "__main__":
    unittest.main()
