Commit a558f7e

Lint fixes test sparsity (#1360)
1 parent cfabd6d · commit a558f7e

File tree: 4 files changed (+75, -32 lines)

ruff.toml

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ include = [
     "test/float8/**/*.py",
     "test/quantization/**/*.py",
     "test/dtypes/**/*.py",
+    "test/sparsity/**/*.py",
     "test/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
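With this entry, the sparsity tests fall under the repository's ruff configuration, so a local run along the lines of ruff check test/sparsity and ruff format --check test/sparsity (assuming ruff is installed) exercises the same rules the reformatted hunks below satisfy.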

test/sparsity/test_fast_sparse_training.py

Lines changed: 24 additions & 11 deletions

@@ -1,19 +1,18 @@
-import logging
-import unittest
 import copy
+import unittest
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torch.testing._internal.common_utils import TestCase
 
 from torchao.sparsity.training import (
+    SemiSparseLinear,
     swap_linear_with_semi_sparse_linear,
     swap_semi_sparse_linear_with_linear,
-    SemiSparseLinear
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode
 
+
 class ToyModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -26,23 +25,26 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
-class TestRuntimeSemiStructuredSparsity(TestCase):
 
+class TestRuntimeSemiStructuredSparsity(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     @unittest.skip("Temporarily skipping to unpin nightlies")
     def test_runtime_weight_sparsification(self):
         # need this import inside to not break 2.2 tests
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+
         input = torch.rand((128, 128)).half().cuda()
         grad = torch.rand((128, 128)).half().cuda()
         model = ToyModel().half().cuda()
         model_c = copy.deepcopy(model)
 
         for name, mod in model.named_modules():
             if isinstance(mod, torch.nn.Linear):
-                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense()
+                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(
+                    mod.weight.detach()
+                ).to_dense()
                 mod.weight = nn.Parameter(sparse)
 
         dense_result = model(input)
@@ -62,8 +64,12 @@ def test_runtime_weight_sparsification(self):
         sparse_result.backward(grad)
 
         # check grad
-        assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1)
-        assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1)
+        assert torch.allclose(
+            model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1
+        )
+        assert torch.allclose(
+            model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1
+        )
 
         # check that swap back works
         swap_semi_sparse_linear_with_linear(model_c)
@@ -77,14 +83,17 @@ def test_runtime_weight_sparsification(self):
     def test_runtime_weight_sparsification_compile(self):
         # need this import inside to not break 2.2 tests
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+
         input = torch.rand((128, 128)).half().cuda()
         grad = torch.rand((128, 128)).half().cuda()
         model = ToyModel().half().cuda()
         model_c = copy.deepcopy(model)
 
         for name, mod in model.named_modules():
             if isinstance(mod, torch.nn.Linear):
-                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense()
+                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(
+                    mod.weight.detach()
+                ).to_dense()
                 mod.weight = nn.Parameter(sparse)
 
         model = torch.compile(model, fullgraph=True)
@@ -106,8 +115,12 @@ def test_runtime_weight_sparsification_compile(self):
         sparse_result.backward(grad)
 
         # check grad
-        assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1)
-        assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1)
+        assert torch.allclose(
+            model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1
+        )
+        assert torch.allclose(
+            model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1
+        )
 
         # check that swap back works
         swap_semi_sparse_linear_with_linear(model_c)
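For context, a minimal usage sketch (not part of this commit) of the runtime-sparsity API these tests exercise. The dict-based config passed to swap_linear_with_semi_sparse_linear and the layer sizes are illustrative assumptions; like the tests, it needs CUDA, fp16 weights, and PyTorch 2.4+.

    import torch
    from torch import nn

    from torchao.sparsity.training import (
        SemiSparseLinear,
        swap_linear_with_semi_sparse_linear,
        swap_semi_sparse_linear_with_linear,
    )

    # Hypothetical two-layer model; sizes chosen to suit the 2:4 fp16 CUDA kernels,
    # not taken from the test file.
    model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).half().cuda()

    # Swap selected nn.Linear modules (keyed by fully qualified name) for
    # SemiSparseLinear, which prunes weights to a 2:4 pattern on the fly.
    sparse_config = {
        name: SemiSparseLinear
        for name, mod in model.named_modules()
        if isinstance(mod, torch.nn.Linear)
    }
    swap_linear_with_semi_sparse_linear(model, sparse_config)

    out = model(torch.rand(128, 128, dtype=torch.half, device="cuda"))
    out.sum().backward()

    # Restore plain nn.Linear modules, mirroring the "swap back works" check above.
    swap_semi_sparse_linear_with_linear(model)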

test/sparsity/test_marlin.py

Lines changed: 44 additions & 19 deletions

@@ -1,28 +1,24 @@
-import torch
 import copy
-import pytest
 
+import pytest
+import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
 from torchao.dtypes import MarlinSparseLayout
-from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
-from torchao.sparsity.marlin import (
-    pack_to_marlin_24,
-    unpack_from_marlin_24,
-    inject_24
-)
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
-    ZeroPointDomain,
-    MappingType,
 )
+from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24
+from torchao.sparsity.sparse_api import apply_fake_sparsity
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 
 class SparseMarlin24(TestCase):
-
     def setUp(self):
         super().setUp()
         torch.manual_seed(0)
@@ -53,7 +49,9 @@ def test_quant_sparse_marlin_layout_eager(self):
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
 
-        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+        assert torch.allclose(
+            dense_result, sparse_result, atol=3e-1
+        ), "Results are not close"
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@@ -71,7 +69,9 @@ def test_quant_sparse_marlin_layout_compile(self):
         self.model.forward = torch.compile(self.model.forward, fullgraph=True)
         sparse_result = self.model(self.input)
 
-        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+        assert torch.allclose(
+            dense_result, sparse_result, atol=3e-1
+        ), "Results are not close"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_pack_unpack_equivalence(self):
@@ -94,9 +94,30 @@ def test_pack_unpack_equivalence(self):
         # Inject 2:4 sparsity mask
         w_24, _ = inject_24(w, *w.shape)
 
-        # Quantize weights
-        scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
-        w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
+        # Quantize weights
+        scales, zeros = choose_qparams_affine(
+            w_24,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+            preserve_zero,
+            zero_point_domain,
+        )
+        w_q_24 = quantize_affine(
+            w_24,
+            block_size,
+            scales,
+            zeros,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
         scales = scales.reshape(-1, w_q_24.shape[1])
 
         # Test pack/unpack equivalence
@@ -107,8 +128,12 @@ def test_pack_unpack_equivalence(self):
             q_w_comp, packed_scales, meta, shape, group_size, num_bits
         )
 
-        assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
-        assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"
+        assert torch.equal(
+            w_q_24, unpacked_q_w
+        ), "Unpacked weights do not match original weights"
+        assert torch.equal(
+            scales, unpacked_scales
+        ), "Unpacked scales do not match original scales"
 
 
 if __name__ == "__main__":
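For context, a rough sketch (not part of this commit) of the quantize-then-compare path that test_quant_sparse_marlin_layout_eager exercises. The model shape, input size, and the point where the reference output is captured are assumptions, since the test's setUp and dense baseline are outside this diff; it needs a CUDA GPU that supports the Marlin 2:4 kernel.

    import torch
    from torch import nn

    from torchao.dtypes import MarlinSparseLayout
    from torchao.quantization.quant_api import int4_weight_only, quantize_
    from torchao.sparsity.sparse_api import apply_fake_sparsity

    # Hypothetical fp16 CUDA model; dims picked to satisfy the int4 group size (128)
    # and the Marlin 2:4 kernel's shape constraints.
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).half().cuda()
    x = torch.randn(32, 4096, dtype=torch.half, device="cuda")

    apply_fake_sparsity(model)   # prune weights to a 2:4 pattern in place
    dense_result = model(x)      # reference output with pruned fp16 weights

    # Quantize to int4 weight-only with the sparse Marlin layout, as in the test.
    quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
    sparse_result = model(x)

    # Same loose tolerance the test uses; int4 rounding shifts the output slightly.
    assert torch.allclose(dense_result, sparse_result, atol=3e-1)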

test/sparsity/test_wanda.py

Lines changed: 6 additions & 2 deletions

@@ -3,12 +3,13 @@
 
 import torch
 from torch import nn
-from torchao.sparsity import WandaSparsifier
 from torch.ao.pruning import FakeSparsity
 from torch.nn.utils.parametrize import is_parametrized
 from torch.testing._internal.common_pruning import SimpleLinear
 from torch.testing._internal.common_utils import TestCase
 
+from torchao.sparsity import WandaSparsifier
+
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
@@ -29,7 +30,9 @@ def test_prepare(self):
         assert hasattr(module.parametrizations["weight"][0], "mask")
         # Check parametrization exists and is correct
         assert is_parametrized(module, "weight")
-        assert type(module.parametrizations.weight[0]) == FakeSparsity
+        assert isinstance(
+            module.parametrizations.weight[0], FakeSparsity
+        ), "FakeSparsity not found"
         # check activation observer is present
         assert hasattr(module, "activation_post_process")
 
@@ -110,5 +113,6 @@ def test_two_layer_mlp_unstructured(self):
 
         sparsifier.squash_mask()
 
+
 if __name__ == "__main__":
     unittest.main()
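For context, a minimal sketch (not part of this commit) of the WandaSparsifier flow these tests walk through: prepare() attaches the FakeSparsity parametrization and an activation observer to each targeted Linear (which test_prepare checks), step() computes the pruning masks from weight magnitudes and activation norms, and squash_mask() folds the masks into the weights. The sparsity_level, toy model, and calibration loop are illustrative assumptions.

    import torch
    from torch import nn

    from torchao.sparsity import WandaSparsifier

    model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))

    sparsifier = WandaSparsifier(sparsity_level=0.5)
    sparsifier.prepare(model, config=None)  # None targets the supported layers

    # A few calibration passes so the activation observers see representative data
    # before the Wanda scores are computed.
    for _ in range(8):
        model(torch.randn(16, 128))

    sparsifier.step()         # compute and apply the pruning masks
    sparsifier.squash_mask()  # fold masks into weights and drop parametrizations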
