Skip to content

Commit 6118cf5

Browse files
committed
[BugFix & Feature]
1.Fix set_max_subnet or set_min_subnet not found exception 2 Add the fleibility to define custom subnet kinds 3.Support to specifiy None with ResorceEstimator Signed-off-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
1 parent 7acc046 commit 6118cf5

File tree

2 files changed

+73
-24
lines changed

2 files changed

+73
-24
lines changed

mmrazor/engine/runner/subnet_val_loop.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
2121
evaluator (Evaluator or dict or list): Used for computing metrics.
2222
fp16 (bool): Whether to enable fp16 validation. Defaults to
2323
False.
24-
evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
25-
or not. Defaults to False.
24+
fix_subnet_kind (str): fix subnet kinds when evaluate, this would be
25+
`sample_kinds` if not specified
2626
calibrate_sample_num (int): The number of images to compute the true
2727
average of per-batch mean/variance instead of the running average.
2828
Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
3636
dataloader: Union[DataLoader, Dict],
3737
evaluator: Union[Evaluator, Dict, List],
3838
fp16: bool = False,
39-
evaluate_fixed_subnet: bool = False,
39+
fix_subnet_kinds: List[str] = [],
4040
calibrate_sample_num: int = 4096,
4141
estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
4242
) -> None:
@@ -48,9 +48,13 @@ def __init__(
4848
model = self.runner.model
4949

5050
self.model = model
51-
self.evaluate_fixed_subnet = evaluate_fixed_subnet
51+
self.evaluate_kinds = fix_subnet_kinds if len(
52+
fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')
53+
5254
self.calibrate_sample_num = calibrate_sample_num
53-
self.estimator = TASK_UTILS.build(estimator_cfg)
55+
self.estimator = None
56+
if estimator_cfg:
57+
self.estimator = TASK_UTILS.build(estimator_cfg)
5458

5559
def run(self):
5660
"""Launch validation."""
@@ -59,24 +63,19 @@ def run(self):
5963

6064
all_metrics = dict()
6165

62-
if self.evaluate_fixed_subnet:
66+
for kind in self.evaluate_kinds:
67+
if kind == 'max':
68+
self.model.mutator.set_max_choices()
69+
elif kind == 'min':
70+
self.model.mutator.set_min_choices()
71+
elif 'random' in kind:
72+
self.model.mutator.set_choices(
73+
self.model.mutator.sample_choices())
74+
else:
75+
raise NotImplementedError(f'{kind}')
76+
6377
metrics = self._evaluate_once()
64-
all_metrics.update(add_prefix(metrics, 'fix_subnet'))
65-
elif hasattr(self.model, 'sample_kinds'):
66-
for kind in self.model.sample_kinds:
67-
if kind == 'max':
68-
self.model.set_max_subnet()
69-
metrics = self._evaluate_once()
70-
all_metrics.update(add_prefix(metrics, 'max_subnet'))
71-
elif kind == 'min':
72-
self.model.set_min_subnet()
73-
metrics = self._evaluate_once()
74-
all_metrics.update(add_prefix(metrics, 'min_subnet'))
75-
elif 'random' in kind:
76-
self.model.mutator.set_choices(
77-
self.model.mutator.sample_choices())
78-
metrics = self._evaluate_once()
79-
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
78+
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
8079

8180
self.runner.call_hook('after_val_epoch', metrics=all_metrics)
8281
self.runner.call_hook('after_val')
@@ -90,7 +89,8 @@ def _evaluate_once(self) -> Dict:
9089
self.run_iter(idx, data_batch)
9190

9291
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
93-
resource_metrics = self.estimator.estimate(self.model)
94-
metrics.update(resource_metrics)
92+
if self.estimator:
93+
resource_metrics = self.estimator.estimate(self.model)
94+
metrics.update(resource_metrics)
9595

9696
return metrics
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from unittest.mock import MagicMock, call, patch
3+
4+
from mmrazor.engine.runner import SubnetValLoop
5+
6+
7+
def test_subnet_val_loop():
8+
runner = MagicMock()
9+
runner.distributed = False
10+
runner.model = MagicMock()
11+
dataloader = MagicMock()
12+
evaluator = [MagicMock()]
13+
fix_subnet_kinds = ['max', 'min']
14+
loop = SubnetValLoop(
15+
runner,
16+
dataloader,
17+
evaluator,
18+
fix_subnet_kinds=fix_subnet_kinds,
19+
estimator_cfg=None)
20+
runner.train_dataloader = MagicMock()
21+
with patch.object(loop, '_evaluate_once') as evaluate_mock:
22+
evaluate_mock.return_value = dict(acc=10)
23+
all_metrics = dict()
24+
all_metrics['max_subnet.acc'] = 10
25+
all_metrics['min_subnet.acc'] = 10
26+
loop.run()
27+
loop.runner.call_hook.assert_has_calls([
28+
call('before_val'),
29+
call('before_val_epoch'),
30+
call('after_val_epoch', metrics=all_metrics),
31+
call('after_val')
32+
])
33+
evaluate_mock.assert_has_calls([(call(), call())])
34+
35+
runner.dataloader = MagicMock()
36+
runner.dataloader.dataset = MagicMock()
37+
loop.dataloader.__iter__.return_value = ['data_batch1']
38+
with patch.object(loop, 'calibrate_bn_statistics') as calibration_bn_mock:
39+
with patch.object(loop, 'run_iter') as run_iter_mock:
40+
eval_result = dict(acc=10)
41+
loop.evaluator.evaluate.return_value = eval_result
42+
result = loop._evaluate_once()
43+
calibration_bn_mock.assert_called_with(runner.train_dataloader,
44+
loop.calibrate_sample_num)
45+
runner.model.eval.assert_called()
46+
run_iter_mock.assert_called_with(0, 'data_batch1')
47+
loop.evaluator.evaluate.assert_called_with(
48+
len(runner.dataloader.dataset))
49+
assert result == eval_result

0 commit comments

Comments
 (0)