
Commit 33d57af

Fixes observer attachment to model based on config for wanda sparsifier (#1265)
* Fixes observer attachment to model based on config for wanda sparsifier
* Handles case when no config is specified
* Added test case in test_wanda.py for custom config
* Lint fix
1 parent a03ca99 commit 33d57af

File tree

2 files changed: +55 −4 lines


test/sparsity/test_wanda.py

Lines changed: 33 additions & 0 deletions
@@ -113,6 +113,39 @@ def test_two_layer_mlp_unstructured(self):

         sparsifier.squash_mask()

+    def test_two_layer_mlp_unstructured_custom_config(self):
+        model = nn.Sequential(
+            nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
+        )  # C_in by C_out
+        X1 = torch.randn(100, 128)  # B1 by C_in
+        X2 = torch.randn(50, 128)  # B2 by C_in
+
+        # Define custom config to sparsify only the first Linear layer for testing
+        config = [{"tensor_fqn": "0.weight"}]
+
+        sparsifier = WandaSparsifier(sparsity_level=0.5)
+        sparsifier.prepare(model, config=config)
+
+        model(X1)
+        model(X2)
+        sparsifier.step()
+
+        cnt = 0
+        for m in model.modules():
+            if isinstance(m, nn.Linear):
+                cnt += 1
+                sparsity_level = (m.weight == 0).float().mean()
+                if cnt == 1:  # First Linear layer should have 50% sparsity
+                    assert (
+                        sparsity_level == 0.5
+                    ), f"sparsity for linear layer {cnt} should be 0.5"
+                else:  # Other layers should not be sparsified
+                    assert (
+                        sparsity_level != 0.5
+                    ), f"sparsity for linear layer {cnt} should not be 0.5"
+
+        sparsifier.squash_mask()
+

 if __name__ == "__main__":
     unittest.main()

torchao/sparsity/wanda.py

Lines changed: 22 additions & 4 deletions
@@ -3,7 +3,7 @@

 import torch
 from torch import nn
-from torch.ao.pruning import BaseSparsifier
+from torch.ao.pruning import BaseSparsifier, get_arg_info_from_tensor_fqn
 from torch.ao.quantization import QConfig, default_placeholder_observer
 from torch.ao.quantization.quantize import _remove_qconfig

@@ -47,9 +47,27 @@ def __init__(
     def prepare(self, model: nn.Module, config: List[Dict]) -> None:
         # activation: use PerChannelNormObserver
         # use no-op placeholder weight observer
-        model.qconfig = QConfig(
-            activation=PerChannelNormObserver, weight=default_placeholder_observer
-        )  # type: ignore[assignment]
+        if config is None:
+            # If no config is provided, apply the qconfig to the entire model
+            model.qconfig = QConfig(
+                activation=PerChannelNormObserver, weight=default_placeholder_observer
+            )  # type: ignore[assignment]
+        else:
+            for module_config in config:
+                tensor_fqn = module_config.get("tensor_fqn", None)
+                if tensor_fqn is None:
+                    raise ValueError("Each config must contain a 'tensor_fqn'.")
+
+                # Extract module information from tensor_fqn
+                info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+                module = info_from_tensor_fqn["module"]
+
+                # Apply the qconfig directly to the module if it exists
+                if module is not None:
+                    module.qconfig = QConfig(
+                        activation=PerChannelNormObserver,
+                        weight=default_placeholder_observer,
+                    )  # type: ignore[assignment]
         torch.ao.quantization.prepare(model, inplace=True)

         # call superclass prepare
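For context, here is a minimal usage sketch of the config-driven path this commit enables. It mirrors the added test; the model, shapes, and sparsity level are illustrative, and the import path is assumed from the file layout above.

import torch
from torch import nn

from torchao.sparsity.wanda import WandaSparsifier  # import path assumed from the file layout

# Illustrative two-layer MLP; only the first Linear appears in the config.
model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))
config = [{"tensor_fqn": "0.weight"}]

sparsifier = WandaSparsifier(sparsity_level=0.5)
# With this fix, prepare() attaches the activation/weight observers only to the
# module named in the config (model[0]); config=None keeps the model-wide behavior.
sparsifier.prepare(model, config=config)

# Calibration forward passes feed the per-channel activation-norm observers.
model(torch.randn(100, 128))

# Compute the WANDA masks and fold them into the weights.
sparsifier.step()
sparsifier.squash_mask()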
