From b5bd199c287ff4f5407f5def4fdd2acfa7c8dfc5 Mon Sep 17 00:00:00 2001
From: Ming-Hsuan-Tu
Date: Sun, 26 Feb 2023 19:58:23 +0800
Subject: [PATCH] [BugFix & Feature] 1. Fix the set_max_subnet/set_min_subnet
 not-found exception 2. Add the flexibility to define custom subnet kinds
 3. Support specifying None for ResourceEstimator

Signed-off-by: Ming-Hsuan-Tu
---
 mmrazor/engine/runner/subnet_val_loop.py | 53 ++++++-----
 .../test_runner/test_subnet_val_loop.py   | 89 +++++++++++++++++++
 2 files changed, 118 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_engine/test_runner/test_subnet_val_loop.py

diff --git a/mmrazor/engine/runner/subnet_val_loop.py b/mmrazor/engine/runner/subnet_val_loop.py
index d61e2747f..551cc039f 100644
--- a/mmrazor/engine/runner/subnet_val_loop.py
+++ b/mmrazor/engine/runner/subnet_val_loop.py
@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
         evaluator (Evaluator or dict or list): Used for computing metrics.
         fp16 (bool): Whether to enable fp16 validation. Defaults to
             False.
-        evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
-            or not. Defaults to False.
+        fix_subnet_kinds (list[str]): The subnet kinds to evaluate. Falls
+            back to the model's `sample_kinds` if not specified.
         calibrate_sample_num (int): The number of images to compute the true
             average of per-batch mean/variance instead of the running
             average. Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
         dataloader: Union[DataLoader, Dict],
         evaluator: Union[Evaluator, Dict, List],
         fp16: bool = False,
-        evaluate_fixed_subnet: bool = False,
+        fix_subnet_kinds: List[str] = [],
         calibrate_sample_num: int = 4096,
         estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
     ) -> None:
@@ -48,9 +48,18 @@ def __init__(
             model = self.runner.model
 
         self.model = model
-        self.evaluate_fixed_subnet = evaluate_fixed_subnet
+        if len(fix_subnet_kinds) == 0 and not hasattr(self.model,
+                                                      'sample_kinds'):
+            raise ValueError(
+                'Neither fix_subnet_kinds nor self.model.sample_kinds exists')
+
+        self.evaluate_kinds = fix_subnet_kinds if len(
+            fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')
+
         self.calibrate_sample_num = calibrate_sample_num
-        self.estimator = TASK_UTILS.build(estimator_cfg)
+        self.estimator = None
+        if estimator_cfg:
+            self.estimator = TASK_UTILS.build(estimator_cfg)
 
     def run(self):
         """Launch validation."""
@@ -59,24 +68,19 @@ def run(self):
 
         all_metrics = dict()
 
-        if self.evaluate_fixed_subnet:
+        for kind in self.evaluate_kinds:
+            if kind == 'max':
+                self.model.mutator.set_max_choices()
+            elif kind == 'min':
+                self.model.mutator.set_min_choices()
+            elif 'random' in kind:
+                self.model.mutator.set_choices(
+                    self.model.mutator.sample_choices())
+            else:
+                raise NotImplementedError(f'Unsupported subnet kind: {kind}')
+
             metrics = self._evaluate_once()
-            all_metrics.update(add_prefix(metrics, 'fix_subnet'))
-        elif hasattr(self.model, 'sample_kinds'):
-            for kind in self.model.sample_kinds:
-                if kind == 'max':
-                    self.model.mutator.set_max_choices()
-                    metrics = self._evaluate_once()
-                    all_metrics.update(add_prefix(metrics, 'max_subnet'))
-                elif kind == 'min':
-                    self.model.mutator.set_min_choices()
-                    metrics = self._evaluate_once()
-                    all_metrics.update(add_prefix(metrics, 'min_subnet'))
-                elif 'random' in kind:
-                    self.model.mutator.set_choices(
-                        self.model.mutator.sample_choices())
-                    metrics = self._evaluate_once()
-                    all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
+            all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
 
         self.runner.call_hook('after_val_epoch', metrics=all_metrics)
         self.runner.call_hook('after_val')
@@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
             self.run_iter(idx, data_batch)
 
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
-        resource_metrics = self.estimator.estimate(self.model)
-        metrics.update(resource_metrics)
+        if self.estimator:
+            resource_metrics = self.estimator.estimate(self.model)
+            metrics.update(resource_metrics)
 
         return metrics
diff --git a/tests/test_engine/test_runner/test_subnet_val_loop.py b/tests/test_engine/test_runner/test_subnet_val_loop.py
new file mode 100644
index 000000000..a7e3e89e4
--- /dev/null
+++ b/tests/test_engine/test_runner/test_subnet_val_loop.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from mmrazor.engine.runner import SubnetValLoop
+
+
+class TestSubnetValLoop(TestCase):
+
+    def test_subnet_val_loop(self):
+        runner = MagicMock()
+        runner.distributed = False
+        runner.model = MagicMock()
+        dataloader = MagicMock()
+        evaluator = [MagicMock()]
+        fix_subnet_kinds = ['max', 'min', 'random']
+        loop = SubnetValLoop(
+            runner, dataloader, evaluator, fix_subnet_kinds=fix_subnet_kinds)
+
+        loop.estimator = MagicMock()
+        loop.estimator.estimate.return_value = dict(flops=10)
+
+        runner.train_dataloader = MagicMock()
+        with patch.object(loop, '_evaluate_once') as evaluate_mock:
+            evaluate_mock.return_value = dict(acc=10)
+            all_metrics = dict()
+            all_metrics['max_subnet.acc'] = 10
+            all_metrics['min_subnet.acc'] = 10
+            all_metrics['random_subnet.acc'] = 10
+            loop.run()
+            loop.runner.call_hook.assert_has_calls([
+                call('before_val'),
+                call('before_val_epoch'),
+                call('after_val_epoch', metrics=all_metrics),
+                call('after_val')
+            ])
+            evaluate_mock.assert_has_calls([call(), call(), call()])
+
+        runner.dataloader = MagicMock()
+        runner.dataloader.dataset = MagicMock()
+        loop.dataloader.__iter__.return_value = ['data_batch1']
+        with patch.object(loop,
+                          'calibrate_bn_statistics') as calibration_bn_mock:
+            with patch.object(loop, 'run_iter') as run_iter_mock:
+                eval_result = dict(acc=10)
+                loop.evaluator.evaluate.return_value = eval_result
+                result = loop._evaluate_once()
+                calibration_bn_mock.assert_called_with(
+                    runner.train_dataloader, loop.calibrate_sample_num)
+                runner.model.eval.assert_called()
+                run_iter_mock.assert_called_with(0, 'data_batch1')
+                loop.evaluator.evaluate.assert_called_with(
+                    len(runner.dataloader.dataset))
+                assert result == eval_result
+                loop.estimator.estimate.assert_called()
+
+    def test_invalid_kind(self):
+        runner = MagicMock()
+        runner.distributed = False
+        runner.model = MagicMock()
+        dataloader = MagicMock()
+        evaluator = [MagicMock()]
+        fix_subnet_kinds = ['invalid']
+        loop = SubnetValLoop(
+            runner,
+            dataloader,
+            evaluator,
+            fix_subnet_kinds=fix_subnet_kinds,
+            estimator_cfg=None)
+        with pytest.raises(NotImplementedError):
+            loop.run()
+
+    def test_subnet_val_loop_with_invalid_value(self):
+        runner = MagicMock()
+        runner.model.module = MagicMock()
+        runner.model.module.__setattr__('sample_kinds', None)
+        del runner.model.module.sample_kinds
+        dataloader = MagicMock()
+        evaluator = [MagicMock()]
+        fix_subnet_kinds = []
+        with pytest.raises(ValueError):
+            SubnetValLoop(
+                runner,
+                dataloader,
+                evaluator,
+                fix_subnet_kinds=fix_subnet_kinds,
+                estimator_cfg=None)
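
A minimal usage sketch of the new interface, assuming the usual MMEngine
convention of selecting the validation loop through a `val_cfg` dict and the
registry string 'mmrazor.SubnetValLoop'; the kind names below are
illustrative, not required values:

    # Evaluate the max/min subnets plus two independent random samples; any
    # kind containing 'random' triggers a fresh mutator sample per the diff.
    val_cfg = dict(
        type='mmrazor.SubnetValLoop',
        fix_subnet_kinds=['max', 'min', 'random0', 'random1'],
        calibrate_sample_num=4096)

    # Skip resource estimation entirely: estimator_cfg=None now leaves
    # self.estimator as None, and _evaluate_once skips the estimate step.
    val_cfg = dict(
        type='mmrazor.SubnetValLoop',
        fix_subnet_kinds=['max', 'min'],
        estimator_cfg=None)

    # With fix_subnet_kinds omitted (or []), the loop falls back to the
    # model's own `sample_kinds`; if neither exists, __init__ raises
    # ValueError up front instead of failing mid-validation.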