Commit eb856cb
[BugFix & Feature]

1. Fix the `set_max_subnet` / `set_min_subnet` not-found exception.
2. Add the flexibility to define custom subnet kinds.
3. Support specifying None for ResourceEstimator.

Signed-off-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>

1 parent 9446b30 commit eb856cb
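
For context, a minimal sketch of how the two new options could be used, assuming the usual mmengine-style `val_cfg` registration of `mmrazor.SubnetValLoop` (the snippet below is illustrative and not taken from this commit): an empty `fix_subnet_kinds` falls back to the model's `sample_kinds`, and `estimator_cfg=None` now skips resource estimation entirely.

# Illustrative mmengine-style config snippet (assumption, not part of this commit):
# evaluate only the largest and smallest subnets and disable the ResourceEstimator.
val_cfg = dict(
    type='mmrazor.SubnetValLoop',
    fix_subnet_kinds=['max', 'min'],  # custom kinds; [] falls back to model.sample_kinds
    calibrate_sample_num=4096,
    estimator_cfg=None)  # None disables resource estimation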

File tree

2 files changed: +79 −24 lines changed


mmrazor/engine/runner/subnet_val_loop.py

Lines changed: 29 additions & 24 deletions
@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
         evaluator (Evaluator or dict or list): Used for computing metrics.
         fp16 (bool): Whether to enable fp16 validation. Defaults to
             False.
-        evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
-            or not. Defaults to False.
+        fix_subnet_kinds (list[str]): Subnet kinds to evaluate. Defaults to
+            the model's `sample_kinds` if not specified.
         calibrate_sample_num (int): The number of images to compute the true
             average of per-batch mean/variance instead of the running average.
             Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
         dataloader: Union[DataLoader, Dict],
         evaluator: Union[Evaluator, Dict, List],
         fp16: bool = False,
-        evaluate_fixed_subnet: bool = False,
+        fix_subnet_kinds: List[str] = [],
         calibrate_sample_num: int = 4096,
         estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
     ) -> None:
@@ -48,9 +48,18 @@ def __init__(
             model = self.runner.model
 
         self.model = model
-        self.evaluate_fixed_subnet = evaluate_fixed_subnet
+        if fix_subnet_kinds is None and not hasattr(self.model,
+                                                    'sample_kinds'):
+            raise ValueError(
+                'neither fix_subnet_kinds nor self.model.sample_kinds exists')
+
+        self.evaluate_kinds = fix_subnet_kinds if len(
+            fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')
+
         self.calibrate_sample_num = calibrate_sample_num
-        self.estimator = TASK_UTILS.build(estimator_cfg)
+        self.estimator = None
+        if estimator_cfg:
+            self.estimator = TASK_UTILS.build(estimator_cfg)
 
     def run(self):
         """Launch validation."""
@@ -59,24 +68,19 @@ def run(self):
 
         all_metrics = dict()
 
-        if self.evaluate_fixed_subnet:
+        for kind in self.evaluate_kinds:
+            if kind == 'max':
+                self.model.mutator.set_max_choices()
+            elif kind == 'min':
+                self.model.mutator.set_min_choices()
+            elif 'random' in kind:
+                self.model.mutator.set_choices(
+                    self.model.mutator.sample_choices())
+            else:
+                raise NotImplementedError(f'Unsupported Subnet {kind}')
+
             metrics = self._evaluate_once()
-            all_metrics.update(add_prefix(metrics, 'fix_subnet'))
-        elif hasattr(self.model, 'sample_kinds'):
-            for kind in self.model.sample_kinds:
-                if kind == 'max':
-                    self.model.mutator.set_max_choices()
-                    metrics = self._evaluate_once()
-                    all_metrics.update(add_prefix(metrics, 'max_subnet'))
-                elif kind == 'min':
-                    self.model.mutator.set_min_choices()
-                    metrics = self._evaluate_once()
-                    all_metrics.update(add_prefix(metrics, 'min_subnet'))
-                elif 'random' in kind:
-                    self.model.mutator.set_choices(
-                        self.model.mutator.sample_choices())
-                    metrics = self._evaluate_once()
-                    all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
+            all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
 
         self.runner.call_hook('after_val_epoch', metrics=all_metrics)
         self.runner.call_hook('after_val')
@@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
             self.run_iter(idx, data_batch)
 
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
-        resource_metrics = self.estimator.estimate(self.model)
-        metrics.update(resource_metrics)
+        if self.estimator:
+            resource_metrics = self.estimator.estimate(self.model)
+            metrics.update(resource_metrics)
 
         return metrics
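
To make the new control flow in `run()` concrete, here is a minimal standalone sketch; `FakeMutator` and the hard-coded metrics are stand-ins rather than mmrazor APIs. It shows how each entry of `fix_subnet_kinds` maps to a mutator call and how every result is stored under a dot-separated `<kind>_subnet` prefix, mirroring what `add_prefix` produces (see the unit test below).

# Standalone sketch of the kind dispatch added above (FakeMutator is a stand-in).
class FakeMutator:

    def set_max_choices(self):
        print('max subnet selected')

    def set_min_choices(self):
        print('min subnet selected')

    def sample_choices(self):
        return {'stage1.depth': 2}

    def set_choices(self, choices):
        print(f'random subnet selected: {choices}')


mutator = FakeMutator()
all_metrics = dict()
for kind in ['max', 'min', 'random0']:  # plays the role of self.evaluate_kinds
    if kind == 'max':
        mutator.set_max_choices()
    elif kind == 'min':
        mutator.set_min_choices()
    elif 'random' in kind:
        mutator.set_choices(mutator.sample_choices())
    else:
        raise NotImplementedError(f'Unsupported Subnet {kind}')

    metrics = dict(acc=70.0)  # stand-in for self._evaluate_once()
    all_metrics.update({f'{kind}_subnet.{k}': v for k, v in metrics.items()})

print(all_metrics)
# {'max_subnet.acc': 70.0, 'min_subnet.acc': 70.0, 'random0_subnet.acc': 70.0}
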
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+from unittest.mock import MagicMock, call, patch
+
+from mmrazor.engine.runner import SubnetValLoop
+
+class TestSubnetValLoop(TestCase):
+    def test_subnet_val_loop(self):
+        runner = MagicMock()
+        runner.distributed = False
+        runner.model = MagicMock()
+        dataloader = MagicMock()
+        evaluator = [MagicMock()]
+        fix_subnet_kinds = ['max', 'min']
+        loop = SubnetValLoop(
+            runner,
+            dataloader,
+            evaluator,
+            fix_subnet_kinds=fix_subnet_kinds,
+            estimator_cfg=None)
+        runner.train_dataloader = MagicMock()
+        with patch.object(loop, '_evaluate_once') as evaluate_mock:
+            evaluate_mock.return_value = dict(acc=10)
+            all_metrics = dict()
+            all_metrics['max_subnet.acc'] = 10
+            all_metrics['min_subnet.acc'] = 10
+            loop.run()
+            loop.runner.call_hook.assert_has_calls([
+                call('before_val'),
+                call('before_val_epoch'),
+                call('after_val_epoch', metrics=all_metrics),
+                call('after_val')
+            ])
+            evaluate_mock.assert_has_calls([call(), call()])
+
+        runner.dataloader = MagicMock()
+        runner.dataloader.dataset = MagicMock()
+        loop.dataloader.__iter__.return_value = ['data_batch1']
+        with patch.object(loop, 'calibrate_bn_statistics') as calibration_bn_mock:
+            with patch.object(loop, 'run_iter') as run_iter_mock:
+                eval_result = dict(acc=10)
+                loop.evaluator.evaluate.return_value = eval_result
+                result = loop._evaluate_once()
+                calibration_bn_mock.assert_called_with(runner.train_dataloader,
+                                                       loop.calibrate_sample_num)
+                runner.model.eval.assert_called()
+                run_iter_mock.assert_called_with(0, 'data_batch1')
+                loop.evaluator.evaluate.assert_called_with(
+                    len(runner.dataloader.dataset))
+                assert result == eval_result
