Skip to content

Commit b5bd199

Browse files
twmhtalec.tu
authored and
alec.tu
committed
[BugFix & Feature]
1.Fix set_max_subnet or set_min_subnet not found exception 2 Add the fleibility to define custom subnet kinds 3.Support to specifiy None with ResorceEstimator Signed-off-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
1 parent 90c5435 commit b5bd199

File tree

2 files changed

+118
-24
lines changed

2 files changed

+118
-24
lines changed

mmrazor/engine/runner/subnet_val_loop.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
2121
evaluator (Evaluator or dict or list): Used for computing metrics.
2222
fp16 (bool): Whether to enable fp16 validation. Defaults to
2323
False.
24-
evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
25-
or not. Defaults to False.
24+
fix_subnet_kind (str): fix subnet kinds when evaluate, this would be
25+
`sample_kinds` if not specified
2626
calibrate_sample_num (int): The number of images to compute the true
2727
average of per-batch mean/variance instead of the running average.
2828
Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
3636
dataloader: Union[DataLoader, Dict],
3737
evaluator: Union[Evaluator, Dict, List],
3838
fp16: bool = False,
39-
evaluate_fixed_subnet: bool = False,
39+
fix_subnet_kinds: List[str] = [],
4040
calibrate_sample_num: int = 4096,
4141
estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
4242
) -> None:
@@ -48,9 +48,18 @@ def __init__(
4848
model = self.runner.model
4949

5050
self.model = model
51-
self.evaluate_fixed_subnet = evaluate_fixed_subnet
51+
if len(fix_subnet_kinds) == 0 and not hasattr(self.model,
52+
'sample_kinds'):
53+
raise ValueError(
54+
'neither fix_subnet_kinds nor self.model.sample_kinds exists')
55+
56+
self.evaluate_kinds = fix_subnet_kinds if len(
57+
fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')
58+
5259
self.calibrate_sample_num = calibrate_sample_num
53-
self.estimator = TASK_UTILS.build(estimator_cfg)
60+
self.estimator = None
61+
if estimator_cfg:
62+
self.estimator = TASK_UTILS.build(estimator_cfg)
5463

5564
def run(self):
5665
"""Launch validation."""
@@ -59,24 +68,19 @@ def run(self):
5968

6069
all_metrics = dict()
6170

62-
if self.evaluate_fixed_subnet:
71+
for kind in self.evaluate_kinds:
72+
if kind == 'max':
73+
self.model.mutator.set_max_choices()
74+
elif kind == 'min':
75+
self.model.mutator.set_min_choices()
76+
elif 'random' in kind:
77+
self.model.mutator.set_choices(
78+
self.model.mutator.sample_choices())
79+
else:
80+
raise NotImplementedError(f'Unsupported Subnet {kind}')
81+
6382
metrics = self._evaluate_once()
64-
all_metrics.update(add_prefix(metrics, 'fix_subnet'))
65-
elif hasattr(self.model, 'sample_kinds'):
66-
for kind in self.model.sample_kinds:
67-
if kind == 'max':
68-
self.model.mutator.set_max_choices()
69-
metrics = self._evaluate_once()
70-
all_metrics.update(add_prefix(metrics, 'max_subnet'))
71-
elif kind == 'min':
72-
self.model.mutator.set_min_choices()
73-
metrics = self._evaluate_once()
74-
all_metrics.update(add_prefix(metrics, 'min_subnet'))
75-
elif 'random' in kind:
76-
self.model.mutator.set_choices(
77-
self.model.mutator.sample_choices())
78-
metrics = self._evaluate_once()
79-
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
83+
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
8084

8185
self.runner.call_hook('after_val_epoch', metrics=all_metrics)
8286
self.runner.call_hook('after_val')
@@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
9094
self.run_iter(idx, data_batch)
9195

9296
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
93-
resource_metrics = self.estimator.estimate(self.model)
94-
metrics.update(resource_metrics)
97+
if self.estimator:
98+
resource_metrics = self.estimator.estimate(self.model)
99+
metrics.update(resource_metrics)
95100

96101
return metrics
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from unittest import TestCase
3+
from unittest.mock import MagicMock, call, patch
4+
5+
import pytest
6+
7+
from mmrazor.engine.runner import SubnetValLoop
8+
9+
10+
class TestSubnetValLoop(TestCase):
11+
12+
def test_subnet_val_loop(self):
13+
runner = MagicMock()
14+
runner.distributed = False
15+
runner.model = MagicMock()
16+
dataloader = MagicMock()
17+
evaluator = [MagicMock()]
18+
fix_subnet_kinds = ['max', 'min', 'random']
19+
loop = SubnetValLoop(
20+
runner, dataloader, evaluator, fix_subnet_kinds=fix_subnet_kinds)
21+
22+
loop.estimator = MagicMock()
23+
loop.estimator.estimate.return_value = dict(flops=10)
24+
25+
runner.train_dataloader = MagicMock()
26+
with patch.object(loop, '_evaluate_once') as evaluate_mock:
27+
evaluate_mock.return_value = dict(acc=10)
28+
all_metrics = dict()
29+
all_metrics['max_subnet.acc'] = 10
30+
all_metrics['min_subnet.acc'] = 10
31+
all_metrics['random_subnet.acc'] = 10
32+
loop.run()
33+
loop.runner.call_hook.assert_has_calls([
34+
call('before_val'),
35+
call('before_val_epoch'),
36+
call('after_val_epoch', metrics=all_metrics),
37+
call('after_val')
38+
])
39+
evaluate_mock.assert_has_calls([call(), call(), call()])
40+
41+
runner.dataloader = MagicMock()
42+
runner.dataloader.dataset = MagicMock()
43+
loop.dataloader.__iter__.return_value = ['data_batch1']
44+
with patch.object(loop,
45+
'calibrate_bn_statistics') as calibration_bn_mock:
46+
with patch.object(loop, 'run_iter') as run_iter_mock:
47+
eval_result = dict(acc=10)
48+
loop.evaluator.evaluate.return_value = eval_result
49+
result = loop._evaluate_once()
50+
calibration_bn_mock.assert_called_with(
51+
runner.train_dataloader, loop.calibrate_sample_num)
52+
runner.model.eval.assert_called()
53+
run_iter_mock.assert_called_with(0, 'data_batch1')
54+
loop.evaluator.evaluate.assert_called_with(
55+
len(runner.dataloader.dataset))
56+
assert result == eval_result
57+
loop.estimator.estimate.assert_called()
58+
59+
def test_invalid_kind(self):
60+
runner = MagicMock()
61+
runner.distributed = False
62+
runner.model = MagicMock()
63+
dataloader = MagicMock()
64+
evaluator = [MagicMock()]
65+
fix_subnet_kinds = ['invalid']
66+
loop = SubnetValLoop(
67+
runner,
68+
dataloader,
69+
evaluator,
70+
fix_subnet_kinds=fix_subnet_kinds,
71+
estimator_cfg=None)
72+
with pytest.raises(NotImplementedError):
73+
loop.run()
74+
75+
def test_subnet_val_loop_with_invalid_value(self):
76+
runner = MagicMock()
77+
runner.model.module = MagicMock()
78+
runner.model.module.__setattr__('sample_kinds', None)
79+
del runner.model.module.sample_kinds
80+
dataloader = MagicMock()
81+
evaluator = [MagicMock()]
82+
fix_subnet_kinds = []
83+
with pytest.raises(ValueError):
84+
SubnetValLoop(
85+
runner,
86+
dataloader,
87+
evaluator,
88+
fix_subnet_kinds=fix_subnet_kinds,
89+
estimator_cfg=None)

0 commit comments

Comments
 (0)