Skip to content

Commit a27952d

Browse files
gaoyang07gaoyang07liukai
authored
[Improvement] Update NasMutator to build search_space in NAS (#426)
* update space_mixin * update NAS algorithms with SpaceMixin * update pruning algorithms with SpaceMixin * fix ut * fix comments * revert _load_fix_subnet_by_mutator * fix dcff test * add ut for registry * update autoslim_greedy_search * fix repeat-mutables bug * fix slice_weight in export_fix_subnet * Update NasMutator: 1. unify mutators for NAS algorithms as the NasMutator; 2. regard ChannelMutator as pruning-specified; 3. remove value_mutators & module_mutators; 4. set GroupMixin only for NAS; 5. revert all changes in ChannelMutator. * update NAS algorithms using NasMutator * update channel mutator * update one_shot_channel_mutator * fix comments * update UT for NasMutator * fix isort version * fix comments --------- Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: liukai <your_email@abc.example>
1 parent b750375 commit a27952d

File tree

71 files changed

+1126
-1686
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+1126
-1686
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ repos:
55
rev: 4.0.1
66
hooks:
77
- id: flake8
8-
- repo: https://github.com/timothycrosley/isort
9-
rev: 5.10.1
8+
- repo: https://github.com/PyCQA/isort
9+
rev: 5.11.5
1010
hooks:
1111
- id: isort
1212
- repo: https://github.com/pre-commit/mirrors-yapf

configs/_base_/settings/cifar10_darts_supernet.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,36 +48,26 @@
4848

4949
# optimizer
5050
optim_wrapper = dict(
51+
constructor='mmrazor.SeparateOptimWrapperConstructor',
5152
architecture=dict(
52-
type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
53-
mutator=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3),
54-
clip_grad=dict(max_norm=5, norm_type=2))
53+
optimizer=dict(
54+
type='mmcls.SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
55+
clip_grad=dict(max_norm=5, norm_type=2)),
56+
mutator=dict(
57+
optimizer=dict(type='mmcls.Adam', lr=3e-4, weight_decay=1e-3)))
5558

59+
search_epochs = 50
5660
# leanring policy
57-
# TODO support different optim use different scheduler (wait mmengine)
5861
param_scheduler = [
5962
dict(
6063
type='mmcls.CosineAnnealingLR',
61-
T_max=50,
64+
T_max=search_epochs,
6265
eta_min=1e-3,
6366
begin=0,
64-
end=50),
67+
end=search_epochs),
6568
]
66-
# param_scheduler = dict(
67-
# architecture = dict(
68-
# type='mmcls.CosineAnnealingLR',
69-
# T_max=50,
70-
# eta_min=1e-3,
71-
# begin=0,
72-
# end=50),
73-
# mutator = dict(
74-
# type='mmcls.ConstantLR',
75-
# factor=1,
76-
# begin=0,
77-
# end=50))
7869

7970
# train, val, test setting
80-
# TODO split cifar dataset
8171
train_cfg = dict(
8272
type='mmrazor.DartsEpochBasedTrainLoop',
8373
mutator_dataloader=dict(
@@ -92,7 +82,7 @@
9282
sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
9383
persistent_workers=True,
9484
),
95-
max_epochs=50)
85+
max_epochs=search_epochs)
9686

9787
val_cfg = dict() # validate each epoch
9888
test_cfg = dict() # dataset settings

configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,7 @@
5353
type='mmrazor.Autoformer',
5454
architecture=supernet,
5555
fix_subnet=None,
56-
mutators=dict(
57-
channel_mutator=dict(
58-
type='mmrazor.OneShotChannelMutator',
59-
channel_unit_cfg={
60-
'type': 'OneShotMutableChannelUnit',
61-
'default_args': {
62-
'unit_predefined': True
63-
}
64-
},
65-
parse_cfg={'type': 'Predefined'}),
66-
value_mutator=dict(type='mmrazor.DynamicValueMutator')))
56+
mutator=dict(type='mmrazor.NasMutator'))
6757

6858
# runtime setting
6959
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

configs/nas/mmcls/bignas/attentive_mobilenet_supernet_32xb64_in1k.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,7 @@
4444
loss_kl=dict(
4545
preds_S=dict(recorder='fc', from_student=True),
4646
preds_T=dict(recorder='fc', from_student=False)))),
47-
mutators=dict(
48-
channel_mutator=dict(
49-
type='mmrazor.OneShotChannelMutator',
50-
channel_unit_cfg={
51-
'type': 'OneShotMutableChannelUnit',
52-
'default_args': {
53-
'unit_predefined': True
54-
}
55-
},
56-
parse_cfg={'type': 'Predefined'}),
57-
value_mutator=dict(type='DynamicValueMutator')))
47+
mutators=dict(type='mmrazor.NasMutator'))
5848

5949
model_wrapper_cfg = dict(
6050
type='mmrazor.BigNASDDP',

configs/nas/mmcls/darts/darts_supernet_unroll_1xb96_cifar10.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
'mmcls::_base_/default_runtime.py',
55
]
66

7-
# model
8-
mutator = dict(type='mmrazor.DiffModuleMutator')
7+
custom_hooks = [
8+
dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True)
9+
]
910

11+
# model
1012
model = dict(
1113
type='mmrazor.Darts',
1214
architecture=dict(
@@ -20,24 +22,12 @@
2022
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
2123
topk=(1, 5),
2224
cal_acc=True)),
23-
mutator=dict(type='mmrazor.DiffModuleMutator'),
25+
mutator=dict(type='mmrazor.NasMutator'),
2426
unroll=True)
2527

2628
model_wrapper_cfg = dict(
2729
type='mmrazor.DartsDDP',
2830
broadcast_buffers=False,
2931
find_unused_parameters=False)
3032

31-
# TRAINING
32-
optim_wrapper = dict(
33-
_delete_=True,
34-
constructor='mmrazor.SeparateOptimWrapperConstructor',
35-
architecture=dict(
36-
type='OptimWrapper',
37-
optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
38-
clip_grad=dict(max_norm=5, norm_type=2)),
39-
mutator=dict(
40-
type='OptimWrapper',
41-
optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
42-
4333
find_unused_parameter = False

configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
mode='original',
2424
loss_weight=1.0),
2525
topk=(1, 5))),
26-
mutator=dict(type='mmrazor.DiffModuleMutator'),
26+
mutator=dict(type='mmrazor.NasMutator'),
2727
pretrain_epochs=15,
2828
finetune_epochs=_base_.search_epochs,
2929
)

configs/nas/mmcls/onceforall/ofa_mobilenet_supernet_32xb64_in1k.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,7 @@
4343
loss_kl=dict(
4444
preds_S=dict(recorder='fc', from_student=True),
4545
preds_T=dict(recorder='fc', from_student=False)))),
46-
mutators=dict(
47-
channel_mutator=dict(
48-
type='mmrazor.OneShotChannelMutator',
49-
channel_unit_cfg={
50-
'type': 'OneShotMutableChannelUnit',
51-
'default_args': {
52-
'unit_predefined': True
53-
}
54-
},
55-
parse_cfg={'type': 'Predefined'}),
56-
value_mutator=dict(type='DynamicValueMutator')))
46+
mutators=dict(type='mmrazor.NasMutator'))
5747

5848
model_wrapper_cfg = dict(
5949
type='mmrazor.BigNASDDP',

configs/nas/mmcls/spos/spos_mobilenet_supernet_8xb128_in1k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@
2525
model = dict(
2626
type='mmrazor.SPOS',
2727
architecture=supernet,
28-
mutator=dict(type='mmrazor.OneShotModuleMutator'))
28+
mutator=dict(type='mmrazor.NasMutator'))
2929

3030
find_unused_parameters = True

configs/nas/mmcls/spos/spos_shufflenet_supernet_8xb128_in1k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@
2525
model = dict(
2626
type='mmrazor.SPOS',
2727
architecture=supernet,
28-
mutator=dict(type='mmrazor.OneShotModuleMutator'))
28+
mutator=dict(type='mmrazor.NasMutator'))
2929

3030
find_unused_parameters = True

configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_supernet_coco_1x.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@
2525
_delete_=True,
2626
type='mmrazor.SPOS',
2727
architecture=supernet,
28-
mutator=dict(type='mmrazor.OneShotModuleMutator'))
28+
mutator=dict(type='mmrazor.NasMutator'))
2929

3030
find_unused_parameters = True

configs/nas/mmdet/detnas/detnas_retina_shufflenet_supernet_coco_1x.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
_delete_=True,
2323
type='mmrazor.SPOS',
2424
architecture=supernet,
25-
mutator=dict(type='mmrazor.OneShotModuleMutator'))
25+
mutator=dict(type='mmrazor.NasMutator'))
2626

2727
find_unused_parameters = True

configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@
7676
type='ChannelAnalyzer',
7777
demo_input=(1, 3, 224, 224),
7878
tracer_type='BackwardTracer')),
79-
fix_subnet=None,
8079
data_preprocessor=None,
8180
target_pruning_ratio=target_pruning_ratio,
8281
step_freq=1,

mmrazor/engine/hooks/dump_subnet_hook.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import copy
23
import os.path as osp
34
from pathlib import Path
45
from typing import Optional, Sequence, Union
@@ -8,6 +9,9 @@
89
from mmengine.hooks import Hook
910
from mmengine.registry import HOOKS
1011

12+
from mmrazor.models.mutables.base_mutable import BaseMutable
13+
from mmrazor.structures import convert_fix_subnet, export_fix_subnet
14+
1115
DATA_BATCH = Optional[Sequence[dict]]
1216

1317

@@ -103,16 +107,25 @@ def after_train_epoch(self, runner) -> None:
103107

104108
@master_only
105109
def _save_subnet(self, runner) -> None:
106-
"""Save the current subnet and delete outdated subnet.
110+
"""Save the current best subnet.
107111
108112
Args:
109113
runner (Runner): The runner of the training process.
110114
"""
115+
model = runner.model.module if runner.distributed else runner.model
111116

112-
if runner.distributed:
113-
subnet_dict = runner.model.module.search_subnet()
114-
else:
115-
subnet_dict = runner.model.search_subnet()
117+
# delete non-leaf tensor to get deepcopy(model).
118+
# TODO solve the hard case.
119+
for module in model.architecture.modules():
120+
if isinstance(module, BaseMutable):
121+
if hasattr(module, 'arch_weights'):
122+
delattr(module, 'arch_weights')
123+
124+
copied_model = copy.deepcopy(model)
125+
copied_model.mutator.set_choices(copied_model.sample_choices())
126+
127+
subnet_dict = export_fix_subnet(copied_model)[0]
128+
subnet_dict = convert_fix_subnet(subnet_dict)
116129

117130
if self.by_epoch:
118131
subnet_filename = self.args.get(

mmrazor/engine/hooks/estimate_resources_hook.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def export_subnet(self, model) -> torch.nn.Module:
104104
"""
105105
# Avoid circular import
106106
from mmrazor.models.mutables.base_mutable import BaseMutable
107-
from mmrazor.structures import load_fix_subnet
107+
from mmrazor.structures import export_fix_subnet, load_fix_subnet
108108

109109
# delete non-leaf tensor to get deepcopy(model).
110110
# TODO solve the hard case.
@@ -114,7 +114,9 @@ def export_subnet(self, model) -> torch.nn.Module:
114114
delattr(module, 'arch_weights')
115115

116116
copied_model = copy.deepcopy(model)
117-
fix_mutable = copied_model.search_subnet()
118-
load_fix_subnet(copied_model, fix_mutable)
117+
copied_model.mutator.set_choices(copied_model.mutator.sample_choices())
118+
119+
subnet_dict = export_fix_subnet(copied_model)[0]
120+
load_fix_subnet(copied_model, subnet_dict)
119121

120122
return copied_model

mmrazor/engine/runner/autoslim_greedy_search_loop.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.utils.data import DataLoader
1212

1313
from mmrazor.registry import LOOPS, TASK_UTILS
14-
from mmrazor.structures import export_fix_subnet
14+
from mmrazor.structures import convert_fix_subnet, export_fix_subnet
1515
from .utils import check_subnet_resources
1616

1717

@@ -68,14 +68,15 @@ def __init__(self,
6868
self.model = runner.model
6969

7070
assert hasattr(self.model, 'mutator')
71-
search_groups = self.model.mutator.search_groups
71+
units = self.model.mutator.mutable_units
72+
7273
self.candidate_choices = {}
73-
for group_id, modules in search_groups.items():
74-
self.candidate_choices[group_id] = modules[0].candidate_choices
74+
for unit in units:
75+
self.candidate_choices[unit.alias] = unit.candidate_choices
7576

7677
self.max_subnet = {}
77-
for group_id, candidate_choices in self.candidate_choices.items():
78-
self.max_subnet[group_id] = len(candidate_choices)
78+
for name, candidate_choices in self.candidate_choices.items():
79+
self.max_subnet[name] = len(candidate_choices)
7980
self.current_subnet = self.max_subnet
8081

8182
current_subnet_choices = self._channel_bins2choices(
@@ -117,7 +118,7 @@ def run(self) -> None:
117118
pruned_subnet[unit_name] -= 1
118119
pruned_subnet_choices = self._channel_bins2choices(
119120
pruned_subnet)
120-
self.model.set_subnet(pruned_subnet_choices)
121+
self.model.mutator.set_choices(pruned_subnet_choices)
121122
metrics = self._val_subnet()
122123
score = metrics[self.score_key] \
123124
if len(metrics) != 0 else 0.
@@ -195,27 +196,16 @@ def _save_searcher_ckpt(self) -> None:
195196

196197
def _save_searched_subnet(self):
197198
"""Save the final searched subnet dict."""
198-
199-
def _convert_fix_subnet(fixed_subnet: Dict[str, Any]):
200-
from mmrazor.utils.typing import DumpChosen
201-
202-
converted_fix_subnet = dict()
203-
for key, val in fixed_subnet.items():
204-
assert isinstance(val, DumpChosen)
205-
converted_fix_subnet[key] = dict(val._asdict())
206-
207-
return converted_fix_subnet
208-
209199
if self.runner.rank != 0:
210200
return
211201
self.runner.logger.info('Search finished:')
212202
for subnet, flops in zip(self.searched_subnet,
213203
self.searched_subnet_flops):
214204
subnet_choice = self._channel_bins2choices(subnet)
215-
self.model.set_subnet(subnet_choice)
205+
self.model.mutator.set_choices(subnet_choice)
216206
fixed_subnet, _ = export_fix_subnet(self.model)
217207
save_name = 'FLOPS_{:.2f}M.yaml'.format(flops)
218-
fixed_subnet = _convert_fix_subnet(fixed_subnet)
208+
fixed_subnet = convert_fix_subnet(fixed_subnet)
219209
fileio.dump(fixed_subnet, osp.join(self.runner.work_dir,
220210
save_name))
221211
self.runner.logger.info(

0 commit comments

Comments
 (0)