Skip to content

Commit 88bbb69

Browse files
committed
[BugFix & Feature]
1. Fix set_max_subnet or set_min_subnet not found exception 2. Add the flexibility to define custom subnet kinds 3. Support specifying None with ResourceEstimator Signed-off-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
1 parent 9446b30 commit 88bbb69

File tree

2 files changed

+82
-24
lines changed

2 files changed

+82
-24
lines changed

mmrazor/engine/runner/subnet_val_loop.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
2121
evaluator (Evaluator or dict or list): Used for computing metrics.
2222
fp16 (bool): Whether to enable fp16 validation. Defaults to
2323
False.
24-
evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
25-
or not. Defaults to False.
24+
fix_subnet_kind (str): fix subnet kinds when evaluate, this would be
25+
`sample_kinds` if not specified
2626
calibrate_sample_num (int): The number of images to compute the true
2727
average of per-batch mean/variance instead of the running average.
2828
Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
3636
dataloader: Union[DataLoader, Dict],
3737
evaluator: Union[Evaluator, Dict, List],
3838
fp16: bool = False,
39-
evaluate_fixed_subnet: bool = False,
39+
fix_subnet_kinds: List[str] = [],
4040
calibrate_sample_num: int = 4096,
4141
estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
4242
) -> None:
@@ -48,9 +48,18 @@ def __init__(
4848
model = self.runner.model
4949

5050
self.model = model
51-
self.evaluate_fixed_subnet = evaluate_fixed_subnet
51+
if fix_subnet_kinds is None and not hasattr(self.model,
52+
'sample_kinds'):
53+
raise ValueError(
54+
'neither fix_subnet_kinds nor self.model.sample_kinds exists')
55+
56+
self.evaluate_kinds = fix_subnet_kinds if len(
57+
fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')
58+
5259
self.calibrate_sample_num = calibrate_sample_num
53-
self.estimator = TASK_UTILS.build(estimator_cfg)
60+
self.estimator = None
61+
if estimator_cfg:
62+
self.estimator = TASK_UTILS.build(estimator_cfg)
5463

5564
def run(self):
5665
"""Launch validation."""
@@ -59,24 +68,19 @@ def run(self):
5968

6069
all_metrics = dict()
6170

62-
if self.evaluate_fixed_subnet:
71+
for kind in self.evaluate_kinds:
72+
if kind == 'max':
73+
self.model.mutator.set_max_choices()
74+
elif kind == 'min':
75+
self.model.mutator.set_min_choices()
76+
elif 'random' in kind:
77+
self.model.mutator.set_choices(
78+
self.model.mutator.sample_choices())
79+
else:
80+
raise NotImplementedError(f'Unsupported Subnet {kind}')
81+
6382
metrics = self._evaluate_once()
64-
all_metrics.update(add_prefix(metrics, 'fix_subnet'))
65-
elif hasattr(self.model, 'sample_kinds'):
66-
for kind in self.model.sample_kinds:
67-
if kind == 'max':
68-
self.model.mutator.set_max_choices()
69-
metrics = self._evaluate_once()
70-
all_metrics.update(add_prefix(metrics, 'max_subnet'))
71-
elif kind == 'min':
72-
self.model.mutator.set_min_choices()
73-
metrics = self._evaluate_once()
74-
all_metrics.update(add_prefix(metrics, 'min_subnet'))
75-
elif 'random' in kind:
76-
self.model.mutator.set_choices(
77-
self.model.mutator.sample_choices())
78-
metrics = self._evaluate_once()
79-
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
83+
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
8084

8185
self.runner.call_hook('after_val_epoch', metrics=all_metrics)
8286
self.runner.call_hook('after_val')
@@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
9094
self.run_iter(idx, data_batch)
9195

9296
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
93-
resource_metrics = self.estimator.estimate(self.model)
94-
metrics.update(resource_metrics)
97+
if self.estimator:
98+
resource_metrics = self.estimator.estimate(self.model)
99+
metrics.update(resource_metrics)
95100

96101
return metrics
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from unittest import TestCase
3+
from unittest.mock import MagicMock, call, patch
4+
5+
from mmrazor.engine.runner import SubnetValLoop
6+
7+
8+
class TestSubnetValLoop(TestCase):
9+
10+
def test_subnet_val_loop():
11+
runner = MagicMock()
12+
runner.distributed = False
13+
runner.model = MagicMock()
14+
dataloader = MagicMock()
15+
evaluator = [MagicMock()]
16+
fix_subnet_kinds = ['max', 'min']
17+
loop = SubnetValLoop(
18+
runner,
19+
dataloader,
20+
evaluator,
21+
fix_subnet_kinds=fix_subnet_kinds,
22+
estimator_cfg=None)
23+
runner.train_dataloader = MagicMock()
24+
with patch.object(loop, '_evaluate_once') as evaluate_mock:
25+
evaluate_mock.return_value = dict(acc=10)
26+
all_metrics = dict()
27+
all_metrics['max_subnet.acc'] = 10
28+
all_metrics['min_subnet.acc'] = 10
29+
loop.run()
30+
loop.runner.call_hook.assert_has_calls([
31+
call('before_val'),
32+
call('before_val_epoch'),
33+
call('after_val_epoch', metrics=all_metrics),
34+
call('after_val')
35+
])
36+
evaluate_mock.assert_has_calls([call(), call()])
37+
38+
runner.dataloader = MagicMock()
39+
runner.dataloader.dataset = MagicMock()
40+
loop.dataloader.__iter__.return_value = ['data_batch1']
41+
with patch.object(loop,
42+
'calibrate_bn_statistics') as calibration_bn_mock:
43+
with patch.object(loop, 'run_iter') as run_iter_mock:
44+
eval_result = dict(acc=10)
45+
loop.evaluator.evaluate.return_value = eval_result
46+
result = loop._evaluate_once()
47+
calibration_bn_mock.assert_called_with(
48+
runner.train_dataloader, loop.calibrate_sample_num)
49+
runner.model.eval.assert_called()
50+
run_iter_mock.assert_called_with(0, 'data_batch1')
51+
loop.evaluator.evaluate.assert_called_with(
52+
len(runner.dataloader.dataset))
53+
assert result == eval_result

0 commit comments

Comments
 (0)