Commit f2c908b

Refactor superblock code + add final benchmark / eval scripts (#691)
I tried to pull out as much shared code as possible into utils.py; now both benchmark.py and evaluate.py are single-function files. I also added a block_sparse_weight function to the BlockSparse subclass. We should probably make this a public API before PTC; I might try to turn this into a good-first-task kind of thing. Additionally, fixed a bug so FakeSparsity parametrizations now return a state_dict, so the masks are present in the dumped file.
1 parent 4082008 commit f2c908b

File tree

9 files changed: +237 / -452 lines changed


test/sparsity/test_parametrization.py

Lines changed: 3 additions & 6 deletions
@@ -130,18 +130,15 @@ def test_state_dict_preserved(self):
             model_load.seq[1].parametrizations["weight"].original,
         )

-        # Check the masks are not preserved in the state_dict
-        # We store the state_dicts in the sparsifier, not in the model itself.
-        # TODO: Need to find a clean way of exporting the parametrized model
-        self.assertNotEqual(
+        self.assertEqual(
             model_save.linear.parametrizations["weight"][0].mask,
             model_load.linear.parametrizations["weight"][0].mask,
         )
-        self.assertNotEqual(
+        self.assertEqual(
             model_save.seq[0].parametrizations["weight"][0].mask,
             model_load.seq[0].parametrizations["weight"][0].mask,
         )
-        self.assertNotEqual(
+        self.assertEqual(
             model_save.seq[1].parametrizations["weight"][0].mask,
             model_load.seq[1].parametrizations["weight"][0].mask,
         )

torchao/sparsity/prototype/sparsifier/utils.py

Lines changed: 0 additions & 6 deletions
@@ -128,9 +128,3 @@ def __init__(self, mask):
     def forward(self, x):
         assert self.mask.shape == x.shape
         return self.mask * x
-
-    def state_dict(self, *args, **kwargs):
-        # We don't want to let the parametrizations to save the mask.
-        # That way we make sure that the linear module doesn't store the masks
-        # alongside their parametrizations.
-        return {}
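
With the empty `state_dict()` override removed, the mask of a `FakeSparsity` parametrization travels with the model's own state dict, which is what the updated test above now asserts. A minimal sketch of the new behaviour, assuming (as in the upstream `torch.ao.pruning` version) that `FakeSparsity.__init__` registers `mask` as a buffer:

```
# Sketch only: shows why removing the empty state_dict() override makes
# the masks appear in the saved checkpoint.
import torch
from torch.nn.utils import parametrize
from torchao.sparsity.prototype.sparsifier.utils import FakeSparsity

linear = torch.nn.Linear(4, 4)
mask = (torch.rand(4, 4) > 0.5).float()
parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))

# Previously this dict contained no mask entry; now the buffer is included.
keys = [k for k in linear.state_dict() if "mask" in k]
print(keys)  # expected to contain something like 'parametrizations.weight.0.mask'
```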

torchao/sparsity/prototype/superblock/README.md

Lines changed: 7 additions & 7 deletions
@@ -1,12 +1,12 @@
 # SuperBlock

-SuperBlock combines two techniques for efficient neural network training and inference: Supermask and Block Compressed Sparse Row (BSR).
+SuperBlock combines two techniques for efficient neural network training and inference: Supermask and Block Compressed Sparse Row (BSR).
 The techniques are described in this [blog post](https://pytorch.org/blog/speeding-up-vits/).

 ### Supermask
 [Supermask](https://arxiv.org/abs/2207.00670) is a technique for applying structured sparsity to neural networks using a learned mask. It works by learning a continuous mask (scores) that is applied element-wise to the weights of a neural network layer. The mask scores are learned separately from the weights and are thresholded based on a target sparsity level to obtain a binary mask. The mask determines which weigths are kept and which are pruned, and is learned during training.

-During inference, the binary mask is applied element-wise to the weights, pruning the weights that correspond to a 0 in the mask, resulting in a sparse network that can be efficiently computed.
+During inference, the binary mask is applied element-wise to the weights, pruning the weights that correspond to a 0 in the mask, resulting in a sparse network that can be efficiently computed.

 ### Block compressed Sparse Row Format (BSR)
 [The BSR format](https://pytorch.org/docs/main/sparse.html#sparse-bsr-tensor) is a sparse matrix representation that stores dense sub-blocks of non-zero elements instead of individual non-zero elements. The matrix is divided into equal-sized blocks, and only the non-zero blocks are stored.
@@ -105,7 +105,7 @@ torchrun --nproc_per_node=8 train.py\
     --model vit_b_16 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\
     --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\
     --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\
-    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema\
     --sparsity-linear 0.9 --sp-linear-tile-size 32
 ```
 Through this command, we are training a `vit_b_16` with 90% sparsity to linear layers using 32x32 tiles.
@@ -124,7 +124,7 @@ NGPUS=1 # put number of available GPUS here

 * Offline sparsification with BSR:
 ```
-torchrun --nproc_per_node=${NGPUS} evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} --sparsify-weights --bsr 32
+python evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} --sparsity bsr --bsr 64
 ```
 This command applies 90% sparsity to linear layers using 32x32 tiles, loads the model weights from ${MODEL_PATH}, loads the ImageNet validation set located at the specified path, applies offline sparsification to the weights, and converts the sparse weights to BSR format with a block size of 32. It is recommended to set `--bsr` the same as tile size.
@@ -184,7 +184,7 @@ python benchmark.py --model vit_b_16 \
     --batch-size 256 \
     --sparsity-linear ${SPARSITY} \
     --sp-linear-tile-size ${BLOCK_SIZE} \
-    --sparsify-weights \
+    --sparsity bsr \
     --bsr ${BLOCK_SIZE} \
     --weights-path ./checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth \
     > /dev/null
@@ -197,7 +197,7 @@ Result:
 ### Evaluate:
 8 x A100 GPUs:
 ```
-torchrun --nproc_per_node=8 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsify-weights --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH}
+torchrun --nproc_per_node=8 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsity bsr --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH}
 ```
 Result:
 ```
@@ -207,7 +207,7 @@ Test: Acc@1 77.644 Acc@5 93.554

 1 x A100 GPUs:
 ```
-torchrun --nproc_per_node=1 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsify-weights --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH}
+torchrun --nproc_per_node=1 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsity bsr --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH}
 ```
 Result:
 ```
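
The README changes above describe BSR only in prose. As a quick illustration (not part of this commit), PyTorch's built-in `to_sparse_bsr` shows what the block layout looks like for a toy weight:

```
# Sketch: dense -> BSR conversion with plain PyTorch, mirroring the README's
# description of storing non-zero blocks rather than individual non-zero elements.
import torch

dense = torch.randn(8, 8)
dense[:, 4:] = 0                        # zero the right half -> one zero block column
bsr = dense.to_sparse_bsr(blocksize=(4, 4))
print(bsr.values().shape)               # (num_nonzero_blocks, 4, 4)
print(bsr.crow_indices(), bsr.col_indices())
```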

torchao/sparsity/prototype/superblock/benchmark.py

Lines changed: 34 additions & 79 deletions
@@ -13,36 +13,10 @@
 import utils
 from torch import nn
 from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
+from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
+from torchao.utils import benchmark_model, profiler_runner

-sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
-from supermask import apply_supermask, SupermaskLinear
-from blocksparse import BlockSparseTensor
-from utils import benchmark_inference
-
-
-def apply_sparsity(model):
-    for name, module in model.named_modules():
-        if isinstance(module, SupermaskLinear) and "mlp" in name:
-            module.sparsify_offline()
-
-
-def apply_bsr(model, blocksize):
-    for name, module in model.named_modules():
-        if isinstance(module, torch.nn.Linear) and "mlp" in name:
-            try:
-                module.weight = torch.nn.Parameter(BlockSparseTensor.from_dense(module.weight.data, blocksize))
-                print(f"Converted {name} to bsr format.")
-            except ValueError as e:
-                print(f"Unable to convert weight of {name} to bsr format: {e}")
-
-
-def verify_sparsity(model):
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Linear):
-            total_weights = module.weight.numel()
-            sparse_weights = (module.weight == 0).sum().item()
-            sparsity_percentage = (sparse_weights / total_weights) * 100
-            print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%")
+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

 @torch.inference_mode
 def main(args):
@@ -54,36 +28,26 @@ def main(args):
     torch.backends.cudnn.deterministic = True
     num_classes = 1000

-    dtype = None
-    if args.bfloat16:
-        print("Using bfloat16")
-        dtype = torch.bfloat16
-    elif args.float16:
-        print("Using float16")
-        dtype = torch.float16
+    dtype = getattr(torch, args.dtype)
+    print(f"Using dtype: {dtype}")

+    # BSR kernel tuning
     if args.bsr and args.tune_kernel_params:
         print("Tuning kernel params")
-        assert args.model == "vit_b_16", "--tune-kernel-params only supported for vit-b-16!"
-        optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-        optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+        if args.model == "vit_b_16":
+            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+        elif args.model == "vit_h_14":
+            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+        else:
+            raise NotImplementedError("Tuning kernel params for this model is not supported yet.")

     print("Creating model")
     model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)

-    apply_supermask(
-        model,
-        linear_sparsity=args.sparsity_linear,
-        linear_sp_tilesize=args.sp_linear_tile_size,
-        conv1x1_sparsity=args.sparsity_conv1x1,
-        conv1x1_sp_tilesize=args.sp_conv1x1_tile_size,
-        conv_sparsity=args.sparsity_conv,
-        conv_sp_tilesize=args.sp_conv_tile_size,
-        skip_last_layer_sparsity=args.skip_last_layer_sparsity,
-        skip_first_transformer_sparsity=args.skip_first_transformer_sparsity,
-        device=device,
-        verbose=False,
-    )
+    # Fake sparsity necessary for BSR
+    simulate_sparsity(model, args)

     if args.weights_path:
         try:
@@ -93,33 +57,24 @@ def main(args):
         except FileNotFoundError:
             raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.")

-    model.to(device)
-
-    if args.sparsify_weights:
-        apply_sparsity(model)
-        verify_sparsity(model)
-
-        # verify correctness
-        # output1 = model(input)
-        # assert torch.allclose(output0, output1), "Output of model before and after weight sparsification should be equal"
+    model.to(device).to(dtype)

-    if dtype:
-        model = model.to(dtype)
+    # Fake sparsity necessary for BSR
+    accelerate_with_sparsity(model, args)

-    if args.bsr:
-        if not args.sparsify_weights:
-            raise ValueError("--bsr can only be used when --sparsify_weights is also specified.")
-        apply_bsr(model, blocksize=args.bsr)
+    # compile
+    model = torch.compile(model, mode='max-autotune', fullgraph=True)

-    # verify correctness
-    # output2 = model(input)
-    # assert torch.allclose(output2, output1), "Output of model before and after changing format to BSR should be equal"
+    # define image
+    image = torch.randn(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device)

-    model = torch.compile(model, mode='max-autotune')
+    # warmup
+    benchmark_model(model, 10, args=(image,))
+    if args.profile:
+        return profiler_runner("test.json.gz", benchmark_model, model, 10, (image,))
+    else:
+        return benchmark_model(model, 100, args=(image,))

-    image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device)
-
-    return benchmark_inference(10, 100, model, image)

@@ -131,15 +86,13 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
     )
-
-    # Mixed precision training parameters
     parser.add_argument(
         "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
     parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load")
-
     # NOTE: sparsity args
+    parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply')
     parser.add_argument("--sparsity-linear", type=float, default=0.0)
     parser.add_argument("--sp-linear-tile-size", type=int, default=1)
     parser.add_argument("--sparsity-conv1x1", type=float, default=0.0)
@@ -148,11 +101,12 @@ def get_args_parser(add_help=True):
     parser.add_argument("--sp-conv-tile-size", type=int, default=1)
     parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)")
     parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)")
-    parser.add_argument('--sparsify-weights', action='store_true', help='Apply weight sparsification in evaluation mode')
     parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)')
-    parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16")
+    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="data type", default="bfloat16")
     parser.add_argument("--float16", action="store_true", help="Use float16")
     parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params")
+    parser.add_argument("--profile", action="store_true", help="Profile the run and dump Prefetto trace")
+    parser.add_argument("--quantization", action="store_true", help="Profile the run and dump Prefetto trace")

     return parser

@@ -161,3 +115,4 @@ def get_args_parser(add_help=True):
     args = get_args_parser().parse_args()
     result = main(args)
     print(f"{result:.3f} ms", file=sys.stderr)
+    print(f"{1000/result:.3f} img/s")

torchao/sparsity/prototype/superblock/blocksparse.py

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,5 @@
+from functools import partial
+
 import torch
 from typing import Optional, Tuple, List, Dict, Any, Callable
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -6,6 +8,8 @@
     _dispatch__torch_function__,
     _dispatch__torch_dispatch__,
 )
+from torchao.quantization.quant_api import _get_linear_subclass_inserter
+
 aten = torch.ops.aten

 # bsr wrapper custom op
@@ -136,3 +140,6 @@ def block_sparse_linear(func, types, args, kwargs):
         w.col_indices(),
         w.values(),
         w.shape[0], w.shape[1], bias)
+
+def block_sparse_weight(blocksize=64):
+    return _get_linear_subclass_inserter(partial(BlockSparseTensor.from_dense, blocksize=blocksize))
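
The commit message floats making `block_sparse_weight` public before PTC. A minimal sketch of how the new helper could be applied to a single linear layer; the call pattern is an assumption based on `_get_linear_subclass_inserter`, which is expected to return a callable that swaps a module's weight for the subclass tensor:

```
# Hedged sketch, not from the commit: wrap one nn.Linear's weight in
# BlockSparseTensor via the new block_sparse_weight helper.
import torch
from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight

linear = torch.nn.Linear(128, 128)
linear.weight.data[:, 64:] = 0    # hand-make a block-sparse pattern
linear = block_sparse_weight(blocksize=64)(linear)
# linear.weight should now be backed by a BlockSparseTensor; the accelerated
# matmul path (Triton BSR kernels) is exercised by the benchmark/eval scripts above.
```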
