Commit 83a2e28

migrate sparsify_ to configs (#1856)
1 parent c04106c commit 83a2e28
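In short, this commit replaces the callable-based sparsity API with config objects that `sparsify_` dispatches on. A minimal before/after sketch based on the README changes below; the toy model, the layer size, and the `apply_fake_sparsity` pruning step are illustrative assumptions, not taken verbatim from this diff:

```py
import torch
from torch import nn

from torchao.sparsity.sparse_api import apply_fake_sparsity, sparsify_, SemiSparseWeightConfig

# Toy model for illustration only; 2:4 kernels expect a CUDA device and fp16/bf16 weights.
model = nn.Sequential(nn.Linear(128, 128)).half().cuda()
apply_fake_sparsity(model)  # prune weights to a 2:4 pattern first

# Before this commit: sparsify_(model, semi_sparse_weight())
# After this commit:  pass a config instance; `semi_sparse_weight` is kept as a BC alias.
sparsify_(model, SemiSparseWeightConfig())
```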

File tree

3 files changed: +59, -25 lines


test/sparsity/test_supermask.py

Lines changed: 3 additions & 8 deletions
@@ -6,8 +6,6 @@
 from torch import nn
 from torch.testing._internal import common_utils
 
-from torchao.sparsity import sparsify_
-
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
@@ -30,13 +28,10 @@ def test_supermask(self, sparsity_level, blocksize):
         from torchao.sparsity import SupermaskLinear
 
         M, N = model[0].weight.shape
-        sparsify_(
-            model,
-            lambda x: SupermaskLinear.from_linear(
-                x, sparsity_level=sparsity_level, blocksize=blocksize
-            ),
+        model[0] = SupermaskLinear.from_linear(
+            model[0], sparsity_level=sparsity_level, blocksize=blocksize
         )
-        sparsify_(model, SupermaskLinear.to_linear)
+        model[0] = SupermaskLinear.to_linear(model[0])
         weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)
 
         # Test correct sparsity level
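For context, the Supermask test now swaps the module directly rather than routing through `sparsify_`. A minimal sketch of the new pattern; the layer size, sparsity level, and blocksize are illustrative values, not taken from the test:

```py
from torch import nn

from torchao.sparsity import SupermaskLinear

model = nn.Sequential(nn.Linear(128, 128))

# Wrap the linear layer with a supermask at the desired sparsity...
model[0] = SupermaskLinear.from_linear(model[0], sparsity_level=0.5, blocksize=4)
# ...and later fold the mask back into a plain nn.Linear with sparsified weights.
model[0] = SupermaskLinear.to_linear(model[0])
```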

torchao/sparsity/README.md

Lines changed: 4 additions & 4 deletions
@@ -78,22 +78,22 @@ quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
 ### 2:4 sparsity
 
 ```py
-from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
+from torchao.sparsity.sparse_api import sparsify_, SemiSparseWeightConfig
 from torchao.dtypes import SemiSparseLayout
 
 model = model.cuda()
-sparsify_(model, semi_sparse_weight())
+sparsify_(model, SemiSparseWeightConfig())
 ```
 
 ### Block sparsity
 We offer prototype support for accelerating block sparsity with our triton kernels for bfloat16/float16 workloads.
 
 ```py
 from torchao.sparsity.sparse_api import sparsify_
-from torchao.sparsity import block_sparse_weight
+from torchao.sparsity import BlockSparseWeightConfig
 
 model = model.cuda()
-sparsify_(model, block_sparse_weight())
+sparsify_(model, BlockSparseWeightConfig())
 ```
 
 # Goal

torchao/sparsity/sparse_api.py

Lines changed: 52 additions & 13 deletions
@@ -1,17 +1,23 @@
-from functools import partial
+import types
+from dataclasses import dataclass
 from typing import Callable, Optional
 
 import torch
 from torch.sparse import to_sparse_semi_structured
 
+from torchao.core.config import AOBaseConfig
 from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
     WeightNormSparsifier,
 )
 from torchao.quantization.quant_api import (
-    _get_linear_subclass_inserter,
     _is_linear,
+    _linear_extra_repr,
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
+    register_quantize_module_handler,
+)
 from torchao.sparsity.blocksparse import BlockSparseTensor
 
 
@@ -35,22 +41,53 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.squash_mask()
 
 
-def block_sparse_weight(blocksize=64):
-    return _get_linear_subclass_inserter(
-        partial(BlockSparseTensor.from_dense, blocksize=blocksize)
-    )
+@dataclass
+class BlockSparseWeightConfig(AOBaseConfig):
+    blocksize: int = 64
+
+
+# for bc
+block_sparse_weight = BlockSparseWeightConfig
+
+
+@register_quantize_module_handler(BlockSparseWeightConfig)
+def _block_sparse_weight_transform(
+    module: torch.nn.Module,
+    config: BlockSparseWeightConfig,
+):
+    blocksize = config.blocksize
+    new_weight = BlockSparseTensor.from_dense(module.weight, blocksize)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
-def semi_sparse_weight():
+class SemiSparseWeightConfig(AOBaseConfig):
     """
-    Convert the weight of linear moduels to semi-structured (2:4) sparsity
+    Configuration for converting the weight of linear modules to semi-structured (2:4) sparsity
     """
-    return _get_linear_subclass_inserter(to_sparse_semi_structured)
+
+    pass
+
+
+# for bc
+semi_sparse_weight = SemiSparseWeightConfig
+
+
+@register_quantize_module_handler(SemiSparseWeightConfig)
+def _semi_sparse_weight_transform(
+    module: torch.nn.Module,
+    config: SemiSparseWeightConfig,
+) -> torch.nn.Module:
+    new_weight = to_sparse_semi_structured(module.weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def sparsify_(
     model: torch.nn.Module,
-    apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+    config: AOBaseConfig,
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
 ) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
@@ -63,8 +100,8 @@ def sparsify_(
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
+        config (AOBaseConfig): a workflow configuration object
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to apply the specified workflow to this module.
 
     **Example:**
     ::
@@ -85,8 +122,10 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
             from torchao.dtypes import SemiSparseLayout
             m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
+    handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
    _replace_with_custom_fn_if_matches_filter(
        model,
-        apply_tensor_subclass,
+        handler,
        _is_linear if filter_fn is None else filter_fn,
+        extra_args=(config,),
    )
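To make the new control flow concrete, here is a minimal sketch of how a custom workflow would plug into the same registry after this change. `IdentityWeightConfig` and its handler are hypothetical names used only for illustration; only `AOBaseConfig`, `register_quantize_module_handler`, and `sparsify_` come from the diff above:

```py
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import register_quantize_module_handler
from torchao.sparsity.sparse_api import sparsify_


@dataclass
class IdentityWeightConfig(AOBaseConfig):
    # Hypothetical config: just rescales weights, only to demonstrate the registry.
    scale: float = 1.0


@register_quantize_module_handler(IdentityWeightConfig)
def _identity_weight_transform(
    module: torch.nn.Module,
    config: IdentityWeightConfig,
) -> torch.nn.Module:
    # Same shape as the handlers in this diff: mutate the module's weight and return it.
    module.weight = torch.nn.Parameter(module.weight * config.scale, requires_grad=False)
    return module


model = torch.nn.Sequential(torch.nn.Linear(8, 8))
# sparsify_ now looks up the handler registered for type(config) in
# _QUANTIZE_CONFIG_HANDLER and calls it on every module passing the filter
# (default: _is_linear), forwarding the config via extra_args.
sparsify_(model, IdentityWeightConfig(scale=0.5))
```

As the new imports show, the registry is the same one `quantize_` uses, which is why the handler decorator lives in `torchao.quantization.transform_module`.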

0 commit comments
