@@ -1,19 +1,18 @@
-import logging
-import unittest
 import copy
+import unittest
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torch.testing._internal.common_utils import TestCase
 
 from torchao.sparsity.training import (
+    SemiSparseLinear,
     swap_linear_with_semi_sparse_linear,
     swap_semi_sparse_linear_with_linear,
-    SemiSparseLinear
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode
 
+
 class ToyModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -26,23 +25,26 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
-class TestRuntimeSemiStructuredSparsity(TestCase):
 
+class TestRuntimeSemiStructuredSparsity(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     @unittest.skip("Temporarily skipping to unpin nightlies")
     def test_runtime_weight_sparsification(self):
         # need this import inside to not break 2.2 tests
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+
         input = torch.rand((128, 128)).half().cuda()
         grad = torch.rand((128, 128)).half().cuda()
         model = ToyModel().half().cuda()
         model_c = copy.deepcopy(model)
 
         for name, mod in model.named_modules():
             if isinstance(mod, torch.nn.Linear):
-                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense()
+                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(
+                    mod.weight.detach()
+                ).to_dense()
                 mod.weight = nn.Parameter(sparse)
 
         dense_result = model(input)
@@ -62,8 +64,12 @@ def test_runtime_weight_sparsification(self):
         sparse_result.backward(grad)
 
         # check grad
-        assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1)
-        assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1)
+        assert torch.allclose(
+            model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1
+        )
+        assert torch.allclose(
+            model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1
+        )
 
         # check that swap back works
         swap_semi_sparse_linear_with_linear(model_c)
@@ -77,14 +83,17 @@ def test_runtime_weight_sparsification(self):
     def test_runtime_weight_sparsification_compile(self):
         # need this import inside to not break 2.2 tests
         from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+
         input = torch.rand((128, 128)).half().cuda()
         grad = torch.rand((128, 128)).half().cuda()
         model = ToyModel().half().cuda()
         model_c = copy.deepcopy(model)
 
         for name, mod in model.named_modules():
             if isinstance(mod, torch.nn.Linear):
-                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense()
+                sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(
+                    mod.weight.detach()
+                ).to_dense()
                 mod.weight = nn.Parameter(sparse)
 
         model = torch.compile(model, fullgraph=True)
@@ -106,8 +115,12 @@ def test_runtime_weight_sparsification_compile(self):
         sparse_result.backward(grad)
 
         # check grad
-        assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1)
-        assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1)
+        assert torch.allclose(
+            model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1
+        )
+        assert torch.allclose(
+            model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1
+        )
 
         # check that swap back works
         swap_semi_sparse_linear_with_linear(model_c)
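
Context for the API under test: swap_linear_with_semi_sparse_linear and swap_semi_sparse_linear_with_linear convert selected nn.Linear modules to SemiSparseLinear for accelerated 2:4 sparse training and back again. A minimal usage sketch follows; the shape of the config argument (a dict mapping module FQNs to SemiSparseLinear) is an assumption from typical torchao usage, since only the imported names are confirmed by this diff.

# Hedged sketch of the runtime sparsity training flow (assumed config shape:
# {module_fqn: SemiSparseLinear}); requires a CUDA device and fp16 weights.
import torch
from torch import nn

from torchao.sparsity.training import (
    SemiSparseLinear,
    swap_linear_with_semi_sparse_linear,
    swap_semi_sparse_linear_with_linear,
)

model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).half().cuda()

# Swap the selected nn.Linear modules for SemiSparseLinear during training.
sparse_config = {"0": SemiSparseLinear, "1": SemiSparseLinear}
swap_linear_with_semi_sparse_linear(model, sparse_config)

out = model(torch.rand((128, 128)).half().cuda())
out.backward(torch.rand_like(out))

# Swap back to plain nn.Linear once training is finished.
swap_semi_sparse_linear_with_linear(model)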