Commit a520746

pytorch_pruner-bugfix-plus-examples (#1250)
1 parent 0791433 commit a520746

15 files changed: +2707 -2 lines changed
examples/README.md
Lines changed: 14 additions & 0 deletions

@@ -565,6 +565,20 @@ Intel® Neural Compressor validated examples with multiple compression technique
       <td>Pattern Lock</td>
       <td><a href="./pytorch/nlp/huggingface_models/text-classification/pruning/pattern_lock/eager">eager</a></td>
     </tr>
+    <tr>
+      <td>Bert-mini</td>
+      <td>Natural Language Processing (text classification)</td>
+      <td>Structured</td>
+      <td>Snip-momentum</td>
+      <td><a href="./pytorch/nlp/huggingface_models/text-classification/pruning/pytorch_pruner/eager">eager</a></td>
+    </tr>
+    <tr>
+      <td>Bert-mini</td>
+      <td>Natural Language Processing (question answering)</td>
+      <td>Structured</td>
+      <td>Snip-momentum</td>
+      <td><a href="./pytorch/nlp/huggingface_models/question-answering/pruning/pytorch_pruner/eager">eager</a></td>
+    </tr>
   </tbody>
 </table>

Lines changed: 145 additions & 0 deletions

# Pytorch Pruner
## Intro
[**Pytorch Pruner**](https://github.com/intel/neural-compressor/tree/master/neural_compressor/experimental/pytorch_pruner) is an INC built-in API that supports a wide range of pruning algorithms, patterns, and pruning schedulers. The following features are currently supported:
> algorithms: magnitude, snip, snip-momentum\
> patterns: NxM, N:M\
> pruning schedulers: iterative pruning scheduler, oneshot pruning scheduler.
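For intuition about the two pattern families, here is a minimal sketch in plain PyTorch (not part of the INC API; the function names are ours) showing how an N:M mask such as 2:4 and an NxM block mask such as 4x1 can be derived from weight magnitudes:

```python
import torch

def nm_mask(weight, n=2, m=4):
    """Keep the n largest-magnitude weights in every group of m consecutive
    weights along each row (the "N:M" pattern, e.g. 2:4)."""
    out_ch, in_ch = weight.shape
    groups = weight.abs().reshape(out_ch, in_ch // m, m)
    topk = groups.topk(n, dim=-1).indices        # winners inside each group of m
    mask = torch.zeros_like(groups)
    mask.scatter_(-1, topk, 1.0)                 # 1 = keep, 0 = prune
    return mask.reshape(out_ch, in_ch)

def nxm_block_mask(weight, block=(4, 1), sparsity=0.5):
    """Score each (block rows x block cols) block by its summed magnitude and
    prune the lowest-scoring fraction of blocks (the "NxM" pattern, e.g. 4x1)."""
    br, bc = block
    out_ch, in_ch = weight.shape
    scores = weight.abs().reshape(out_ch // br, br, in_ch // bc, bc).sum(dim=(1, 3))
    threshold = torch.quantile(scores.flatten(), sparsity)
    block_mask = (scores > threshold).float()    # one entry per block
    return block_mask.repeat_interleave(br, 0).repeat_interleave(bc, 1)

w = torch.randn(8, 16)
print(nm_mask(w).mean())         # ~0.5: exactly 2 of every 4 weights survive
print(nxm_block_mask(w).mean())  # ~0.5 of the 4x1 blocks survive
```

Note that an N:M pattern fixes the sparsity at 1 - N/M within every group, while NxM block pruning removes whole blocks until a chosen target sparsity is reached.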
## Usage
### Write a config yaml file
Pytorch Pruner is developed on top of [pruning](https://github.com/intel/neural-compressor/blob/master/neural_compressor/experimental/pruning.py), so most usages are identical. The API reads a yaml configuration file to define a Pruning object. Here is a bert-mini example:
```yaml
version: 1.0

model:
  name: "bert-mini"
  framework: "pytorch"

pruning:
  approach:
    weight_compression_pytorch:
      # Global settings
      # If start_step equals end_step, the oneshot pruning scheduler is enabled; otherwise the API automatically uses the iterative pruning scheduler.
      start_step: 0 # step at which the pruning process begins
      end_step: 0 # step at which the pruning process ends
      not_to_prune_names: ["classifier", "pooler", ".*embeddings*"] # a global list of layers which you do not wish to prune
      prune_layer_type: ["Linear"] # the module types you want to prune (Linear, Conv2d, etc.)
      target_sparsity: 0.9 # the sparsity level you want the model pruned to
      max_sparsity_ratio_per_layer: 0.98 # the maximum sparsity ratio a single layer can reach

      pruners: # each "Pruner" below defines a pruning process for a group of layers, which lets you apply different pruning methods to different layers in one model
        # Local settings
        - !Pruner
            exclude_names: [".*query", ".*key", ".*value"] # regular expressions for layer names you do not wish to include in this pruner
            pattern: "1x1" # pattern type; "NxM" and "N:M" are supported
            update_frequency_on_step: 100 # when the iterative pruning scheduler is used, this defines the pruning frequency
            prune_domain: "global" # one of ["global", "local"]; whether the score map is computed over the entire set of parameters or only over the corresponding layer's weights
            prune_type: "snip_momentum" # pruning algorithm; refer to pytorch_pruner/pruner.py
            sparsity_decay_type: "exp" # one of ["linear", "cos", "exp", "cube"]; how the target sparsity evolves during iterative pruning
        - !Pruner
            exclude_names: [".*output", ".*intermediate"]
            pattern: "4x1"
            update_frequency_on_step: 100
            prune_domain: "global"
            prune_type: "snip_momentum"
            sparsity_decay_type: "exp"
```
Please be aware that when a keyword appears in both the global and local settings, the **local** setting takes priority.
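As a small illustration of that precedence rule (a sketch of the behaviour described above, not the INC implementation), a key repeated inside a `!Pruner` entry simply overrides the global default for that pruner:

```python
# Global keys act as defaults; keys repeated inside a !Pruner entry win for that pruner.
global_settings = {"target_sparsity": 0.9, "prune_domain": "global", "update_frequency_on_step": 1000}
pruner_settings = {"pattern": "4x1", "update_frequency_on_step": 100}

effective = {**global_settings, **pruner_settings}
print(effective["update_frequency_on_step"])  # 100 -- the local value takes priority
```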
### Coding template:
With the config file settled, we provide a template for using the pytorch_pruner API:
```python
model = Model()
criterion = Criterion()
optimizer = Optimizer()
args = Args()

from neural_compressor.experimental.pytorch_pruner.pruning import Pruning

pruner = Pruning("path/to/your/config.yaml")
if args.do_prune:
    # iterative pruning: start after the warm-up epochs, end at the last training step
    pruner.update_items_for_all_pruners(start_step=int(args.sparsity_warm_epochs * num_iterations), end_step=int(total_iterations))
else:
    # push the pruning window past the end of training, which effectively disables the pruner
    pruner.update_items_for_all_pruners(start_step=total_iterations + 1, end_step=total_iterations + 1)
pruner.model = model
pruner.on_train_begin()
for epoch in range(epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        pruner.on_step_begin(step)
        output = model(**batch)
        loss = output.loss
        loss.backward()
        pruner.on_before_optimizer_step()
        optimizer.step()
        pruner.on_after_optimizer_step()
        optimizer.zero_grad()

    model.eval()
    for step, batch in enumerate(val_dataloader):
        ...
```
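The template leaves a few bookkeeping names (`num_iterations`, `total_iterations`, `args.sparsity_warm_epochs`) undefined; one plausible way to fill them in for a standard epoch-based loop (an assumption on our part, not mandated by the API) is:

```python
# Assumed bookkeeping for the placeholders used in the template above
# (adapt these to your own training setup).
num_iterations = len(train_dataloader)                        # optimizer steps per epoch
total_iterations = num_iterations * args.num_train_epochs     # steps over the whole run
start_step = int(args.sparsity_warm_epochs * num_iterations)  # prune after the dense warm-up epochs
end_step = int(total_iterations)                              # stop pruning at the last step
```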
For more usage, please refer to our example code below.

## Examples
We provide several pruning examples, trained on different datasets/tasks and using different sparsity patterns. We are working on sharing our sparse models on HuggingFace.
### [SQuAD](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/question-answering/pruning)
We can train a sparse model with an N:M (2:4) pattern:
```shell
python3 ./run_qa_no_trainer.py \
        --model_name_or_path "/path/to/dense_finetuned_model/" \
        --pruning_config "./bert_mini_2:4.yaml" \
        --dataset_name "squad" \
        --max_seq_length "384" \
        --doc_stride "128" \
        --per_device_train_batch_size "8" \
        --weight_decay "1e-7" \
        --learning_rate "1e-4" \
        --num_train_epochs 10 \
        --teacher_model_name_or_path "/path/to/dense_finetuned_model/" \
        --distill_loss_weight "8.0"
```
We can also choose 4x1 as the pruning pattern:
```shell
python ./run_qa_no_trainer.py \
        --model_name_or_path "/path/to/dense_finetuned_model/" \
        --pruning_config "./bert_mini_4x1.yaml" \
        --dataset_name "squad" \
        --max_seq_length "384" \
        --doc_stride "128" \
        --per_device_train_batch_size "16" \
        --per_device_eval_batch_size "16" \
        --num_warmup_steps "1000" \
        --do_prune \
        --cooldown_epochs 5 \
        --learning_rate "4.5e-4" \
        --num_train_epochs 10 \
        --weight_decay "1e-7" \
        --output_dir "pruned_squad_bert-mini" \
        --teacher_model_name_or_path "/path/to/dense_finetuned_model/" \
        --distill_loss_weight "4.5"
```
Dense model training is also supported, as follows (i.e. without passing --do_prune):
```shell
python ./run_qa_no_trainer.py \
        --model_name_or_path "prajjwal1/bert-mini" \
        --pruning_config "./bert_mini_4x1.yaml" \
        --dataset_name "squad" \
        --max_seq_length "384" \
        --doc_stride "128" \
        --per_device_train_batch_size "8" \
        --per_device_eval_batch_size "16" \
        --num_warmup_steps "1000" \
        --learning_rate "5e-5" \
        --num_train_epochs 5 \
        --output_dir "./output_bert-mini"
```
### Results
| Model | Dataset | Sparsity pattern | Pruning method | Element-wise sparsity (matmul/Gemm/conv layers) | Init model | Dense F1 (mean/max) | Sparse F1 (mean/max) | Relative drop |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Bert-Mini | SQuAD | 4x1 | Snip-momentum | 0.7993 | Dense & Finetuned | 0.7662/0.7687 | 0.7617/0.7627 | -0.78% |
| Bert-Mini | SQuAD | 2:4 | Snip-momentum | 0.4795 | Dense & Finetuned | 0.7662/0.7687 | 0.7645/0.7685 | -0.02% |

## References
* [SNIP: Single-shot Network Pruning based on Connection Sensitivity](https://arxiv.org/abs/1810.02340)
* [Knowledge Distillation with the Reused Teacher Classifier](https://arxiv.org/abs/2203.14001)
Lines changed: 30 additions & 0 deletions

```yaml
version: 1.0

model:
  name: "bert-mini"
  framework: "pytorch"

pruning:
  approach:
    weight_compression_pytorch:
      start_step: 2
      end_step: 2
      not_to_prune_names: ["qa_outputs", "pooler", ".*embeddings*", "layer.3.attention.output.dense"]
      prune_layer_type: ["Linear"]
      target_sparsity: 0.5
      update_frequency_on_step: 1000
      max_sparsity_ratio_per_layer: 0.98
      prune_domain: "global"

      sparsity_decay_type: "exp"
      pruners:
        - !Pruner
            pattern: "2:4"
            target_sparsity: 0.5
            update_frequency_on_step: 1000
            prune_domain: "global"
            prune_type: "snip_momentum"
            sparsity_decay_type: "exp"
```
Lines changed: 27 additions & 0 deletions

```yaml
version: 1.0

model:
  name: "bert-mini"
  framework: "pytorch"

pruning:
  approach:
    weight_compression_pytorch:
      start_step: 0
      end_step: 0
      not_to_prune_names: ["qa_outputs", "pooler", ".*embeddings*", "layer.3.attention.output.dense"]
      prune_layer_type: ["Linear"]
      target_sparsity: 0.8
      update_frequency_on_step: 1000
      max_sparsity_ratio_per_layer: 0.98
      prune_domain: "global"

      sparsity_decay_type: "exp"
      pruners:
        - !Pruner
            pattern: "oc_pattern_4x1"
            target_sparsity: 0.8
            update_frequency_on_step: 1000
            prune_domain: "global"
            prune_type: "snip_momentum"
            sparsity_decay_type: "exp"
```
