Commit 2017c06

[Docs] Add KD algo: DFND
1 parent 90c7af1 commit 2017c06

9 files changed: +485 −4 lines changed

configs/distill/mmcls/dfnd/README.md

Lines changed: 31 additions & 0 deletions
# Learning Student Networks in the Wild (DFND)

> [Learning Student Networks in the Wild](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf)

<!-- [ALGORITHM] -->

## Abstract

Data-free learning for student networks is a new paradigm for addressing users' anxiety over the privacy of the original training data. Since the architectures of modern convolutional neural networks (CNNs) are compact and sophisticated, the alternative images or meta-data generated from the teacher network are often broken, so the student network cannot achieve performance comparable to that of the pre-trained teacher network, especially on large-scale image datasets. Different from previous works, we propose to maximally utilize the massive unlabeled data available in the wild. Specifically, we first thoroughly analyze the output differences between the teacher and student networks on the original data and develop a data collection method. Then, a noisy knowledge distillation algorithm is proposed for improving the performance of the student network. In practice, an adaptation matrix is learned along with the student network to correct the label noise produced by the teacher network on the collected unlabeled images. The effectiveness of our DFND (Data-Free Noisy Distillation) method is verified on several benchmarks, demonstrating its superiority over state-of-the-art data-free distillation methods. Experiments on various datasets show that student networks learned by the proposed method achieve performance comparable to those trained on the original dataset.

<img width="910" alt="pipeline" src="./dfnd.PNG">
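The committed config (later in this diff) drives distillation through a `DFNDLoss` configured with `tau`, `num_classes`, and `batch_select`. As one hedged reading of the abstract, the sketch below shows a selective KD loss that distills only on the unlabeled samples where the teacher is most confident; the entropy-based selection rule and the helper `selective_kd_loss` are illustrative assumptions, not the actual mmrazor implementation, and the learned adaptation matrix is omitted.

```python
import torch
import torch.nn.functional as F


def selective_kd_loss(preds_S: torch.Tensor,
                      preds_T: torch.Tensor,
                      tau: float = 4.0,
                      batch_select: float = 0.5) -> torch.Tensor:
    """Illustrative DFND-style noisy KD loss (not the actual `DFNDLoss`).

    Distills only on the fraction of wild samples where the teacher is
    most confident, measured by the entropy of its predictions.
    """
    with torch.no_grad():
        probs_T = F.softmax(preds_T, dim=1)
        entropy = -(probs_T * probs_T.clamp_min(1e-12).log()).sum(dim=1)
        # Keep the `batch_select` fraction with the lowest entropy.
        num_keep = max(1, int(batch_select * preds_T.size(0)))
        keep = entropy.topk(num_keep, largest=False).indices

    # Standard temperature-scaled KL divergence on the selected subset.
    log_p_S = F.log_softmax(preds_S[keep] / tau, dim=1)
    p_T = F.softmax(preds_T[keep] / tau, dim=1)
    return F.kl_div(log_p_S, p_T, reduction='batchmean') * tau**2
```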
## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :---: | :---: | :---: | :----: | :----: | :---: | :------: |
| backbone & logits | CIFAR-10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 94.78 | 95.34 | 94.82 | [config](./dfnd_logits_resnet34_resnet18_8xb32_cifar10.py) | [student](https://drive.google.com/file/d/1_MekfTkCsEl68meWPqtdNZIxdJO2R2Eb/view?usp=drive_link) |
## Citation

```latex
@inproceedings{chen2021learning,
  title={Learning student networks in the wild},
  author={Chen, Hanting and Guo, Tianyu and Xu, Chang and Li, Wenshuo and Xu, Chunjing and Xu, Chao and Wang, Yunhe},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={6428--6437},
  year={2021}
}
```

configs/distill/mmcls/dfnd/dfnd.PNG

644 KB
configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py

Lines changed: 100 additions & 0 deletions
```python
_base_ = ['mmcls::_base_/default_runtime.py']

# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[320, 640], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=800, val_interval=1)
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=32),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackClsInputs'),
]

# DFND trains on unlabeled "wild" data (ImageNet here) ...
train_dataloader = dict(
    batch_size=256,
    num_workers=5,
    dataset=dict(
        type='ImageNet',
        data_root='/cache/data/imagenet/',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

test_pipeline = [
    dict(type='PackClsInputs'),
]

# ... and validates on the target dataset (CIFAR-10).
val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type='CIFAR10',
        data_prefix='/cache/data/cifar',
        test_mode=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

teacher_ckpt = '/cache/models/resnet_model.pth'

model = dict(
    _scope_='mmrazor',
    type='DFNDDistill',
    calculate_student_loss=False,
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters (ImageNet statistics)
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    val_data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters (CIFAR-10 statistics)
        mean=[125.307, 122.961, 113.8575],
        std=[51.5865, 50.847, 51.255],
        bgr_to_rgb=False),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(
                type='DFNDLoss',
                tau=4,
                loss_weight=1,
                num_classes=10,
                batch_select=0.5)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(type='mmrazor.DFNDValLoop')
```
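For reference, a minimal sketch of launching this config through the standard MMEngine runner; the `work_dir` value is a placeholder, and the dataset roots and teacher checkpoint referenced in the config are assumed to exist locally.

```python
# Minimal launch sketch using the MMEngine runner.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py')
cfg.work_dir = './work_dirs/dfnd_resnet34_resnet18_cifar10'  # placeholder

runner = Runner.from_cfg(cfg)
runner.train()  # validation runs through mmrazor.DFNDValLoop every epoch
```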

mmrazor/engine/runner/__init__.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
-from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
+from .distill_val_loop import (DFNDValLoop, SelfDistillValLoop,
+                               SingleTeacherDistillValLoop)
 from .evolution_search_loop import EvolutionSearchLoop
 from .iteprune_val_loop import ItePruneValLoop
 from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop,
@@ -15,5 +16,5 @@
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
     'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop',
-    'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop'
+    'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop', 'DFNDValLoop'
 ]
```

mmrazor/engine/runner/distill_val_loop.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -125,3 +125,38 @@ def run(self):
 
         self.runner.call_hook('after_val_epoch', metrics=student_metrics)
         self.runner.call_hook('after_val')
+
+
+@LOOPS.register_module()
+class DFNDValLoop(SingleTeacherDistillValLoop):
+    """Validation loop for DFND. DFND requires different datasets for
+    training and validation.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 validation. Defaults to
+            False.
+    """
+
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 fp16: bool = False) -> None:
+        super().__init__(runner, dataloader, evaluator, fp16)
+        if self.runner.distributed:
+            assert hasattr(self.runner.model.module, 'teacher')
+            # TODO: remove hard code after mmcls adds data_preprocessor
+            data_preprocessor = self.runner.model.module.val_data_preprocessor
+            self.teacher = self.runner.model.module.teacher
+            self.teacher.data_preprocessor = data_preprocessor
+
+        else:
+            assert hasattr(self.runner.model, 'teacher')
+            # TODO: remove hard code after mmcls adds data_preprocessor
+            data_preprocessor = self.runner.model.val_data_preprocessor
+            self.teacher = self.runner.model.teacher
+            self.teacher.data_preprocessor = data_preprocessor
```
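The preprocessor swap above follows from the config: the training dataloader feeds wild unlabeled ImageNet images normalized with ImageNet statistics (`data_preprocessor`), while validation evaluates the teacher on CIFAR-10, so the loop points the teacher at `val_data_preprocessor`, which carries the CIFAR-10 statistics.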
mmrazor/models/algorithms/distill/configurable/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .datafree_distillation import (DAFLDataFreeDistillation,
                                     DataFreeDistillation)
+from .dfnd_distill import DFNDDistill
 from .fpn_teacher_distill import FpnTeacherDistill
 from .overhaul_feature_distillation import OverhaulFeatureDistillation
 from .self_distill import SelfDistill
@@ -9,5 +10,5 @@
 __all__ = [
     'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
     'DataFreeDistillation', 'DAFLDataFreeDistillation',
-    'OverhaulFeatureDistillation'
+    'OverhaulFeatureDistillation', 'DFNDDistill'
 ]
```
