Skip to content

[Improvement] Update SubnetValLoop #465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions mmrazor/engine/runner/subnet_val_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
or not. Defaults to False.
        fix_subnet_kinds (List[str]): Subnet kinds to evaluate (e.g. 'max',
            'min', 'random'). Defaults to the model's `sample_kinds` if empty.
calibrate_sample_num (int): The number of images to compute the true
average of per-batch mean/variance instead of the running average.
Defaults to 4096.
Expand All @@ -36,7 +36,7 @@ def __init__(
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False,
evaluate_fixed_subnet: bool = False,
fix_subnet_kinds: List[str] = [],
calibrate_sample_num: int = 4096,
estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
) -> None:
Expand All @@ -48,9 +48,18 @@ def __init__(
model = self.runner.model

self.model = model
self.evaluate_fixed_subnet = evaluate_fixed_subnet
if len(fix_subnet_kinds) == 0 and not hasattr(self.model,
'sample_kinds'):
raise ValueError(
'neither fix_subnet_kinds nor self.model.sample_kinds exists')

self.evaluate_kinds = fix_subnet_kinds if len(
fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')

self.calibrate_sample_num = calibrate_sample_num
self.estimator = TASK_UTILS.build(estimator_cfg)
self.estimator = None
if estimator_cfg:
self.estimator = TASK_UTILS.build(estimator_cfg)

def run(self):
"""Launch validation."""
Expand All @@ -59,24 +68,19 @@ def run(self):

all_metrics = dict()

if self.evaluate_fixed_subnet:
for kind in self.evaluate_kinds:
if kind == 'max':
self.model.mutator.set_max_choices()
elif kind == 'min':
self.model.mutator.set_min_choices()
elif 'random' in kind:
self.model.mutator.set_choices(
self.model.mutator.sample_choices())
else:
raise NotImplementedError(f'Unsupported Subnet {kind}')

metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, 'fix_subnet'))
elif hasattr(self.model, 'sample_kinds'):
for kind in self.model.sample_kinds:
if kind == 'max':
self.model.mutator.set_max_choices()
metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, 'max_subnet'))
elif kind == 'min':
self.model.mutator.set_min_choices()
metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, 'min_subnet'))
elif 'random' in kind:
self.model.mutator.set_choices(
self.model.mutator.sample_choices())
metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))

self.runner.call_hook('after_val_epoch', metrics=all_metrics)
self.runner.call_hook('after_val')
Expand All @@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
self.run_iter(idx, data_batch)

metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
resource_metrics = self.estimator.estimate(self.model)
metrics.update(resource_metrics)
if self.estimator:
resource_metrics = self.estimator.estimate(self.model)
metrics.update(resource_metrics)

return metrics
89 changes: 89 additions & 0 deletions tests/test_engine/test_runner/test_subnet_val_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, call, patch

import pytest

from mmrazor.engine.runner import SubnetValLoop


class TestSubnetValLoop(TestCase):
    """Unit tests for ``SubnetValLoop`` built entirely from mocks.

    No real model, dataloader or evaluator is constructed: the runner, model
    and evaluator are ``MagicMock`` stand-ins, so these tests only verify the
    loop's control flow and the calls it makes on its collaborators.
    """

    def test_subnet_val_loop(self):
        """Happy path: ``run()`` evaluates each requested subnet kind, and
        ``_evaluate_once()`` recalibrates BN, iterates the dataloader and
        merges the estimator's resource metrics."""
        runner = MagicMock()
        # Non-distributed so the loop uses runner.model directly
        # (presumably the distributed path unwraps .module — see __init__).
        runner.distributed = False
        runner.model = MagicMock()
        dataloader = MagicMock()
        evaluator = [MagicMock()]
        # Explicit kinds override the model's own sample_kinds.
        fix_subnet_kinds = ['max', 'min', 'random']
        loop = SubnetValLoop(
            runner, dataloader, evaluator, fix_subnet_kinds=fix_subnet_kinds)

        # Replace the built estimator so estimate() returns a known dict.
        loop.estimator = MagicMock()
        loop.estimator.estimate.return_value = dict(flops=10)

        runner.train_dataloader = MagicMock()
        with patch.object(loop, '_evaluate_once') as evaluate_mock:
            evaluate_mock.return_value = dict(acc=10)
            # Expected aggregate: one '<kind>_subnet.'-prefixed entry per kind.
            all_metrics = dict()
            all_metrics['max_subnet.acc'] = 10
            all_metrics['min_subnet.acc'] = 10
            all_metrics['random_subnet.acc'] = 10
            loop.run()
            # run() must fire the standard validation hooks in order and hand
            # the aggregated metrics to 'after_val_epoch'.
            loop.runner.call_hook.assert_has_calls([
                call('before_val'),
                call('before_val_epoch'),
                call('after_val_epoch', metrics=all_metrics),
                call('after_val')
            ])
            # One evaluation per entry in fix_subnet_kinds.
            evaluate_mock.assert_has_calls([call(), call(), call()])

        runner.dataloader = MagicMock()
        runner.dataloader.dataset = MagicMock()
        # Single-batch dataloader so run_iter is invoked exactly once.
        loop.dataloader.__iter__.return_value = ['data_batch1']
        with patch.object(loop,
                          'calibrate_bn_statistics') as calibration_bn_mock:
            with patch.object(loop, 'run_iter') as run_iter_mock:
                eval_result = dict(acc=10)
                loop.evaluator.evaluate.return_value = eval_result
                result = loop._evaluate_once()
                # BN statistics must be recalibrated on the *train* dataloader
                # before the eval pass.
                calibration_bn_mock.assert_called_with(
                    runner.train_dataloader, loop.calibrate_sample_num)
                runner.model.eval.assert_called()
                run_iter_mock.assert_called_with(0, 'data_batch1')
                loop.evaluator.evaluate.assert_called_with(
                    len(runner.dataloader.dataset))
                assert result == eval_result
                # With a non-None estimator, resources are estimated too.
                loop.estimator.estimate.assert_called()

    def test_invalid_kind(self):
        """An unrecognised subnet kind must raise NotImplementedError when
        ``run()`` dispatches on it (construction itself succeeds)."""
        runner = MagicMock()
        runner.distributed = False
        runner.model = MagicMock()
        dataloader = MagicMock()
        evaluator = [MagicMock()]
        fix_subnet_kinds = ['invalid']
        loop = SubnetValLoop(
            runner,
            dataloader,
            evaluator,
            fix_subnet_kinds=fix_subnet_kinds,
            estimator_cfg=None)
        with pytest.raises(NotImplementedError):
            loop.run()

    def test_subnet_val_loop_with_invalid_value(self):
        """Empty ``fix_subnet_kinds`` plus a model lacking ``sample_kinds``
        must raise ValueError during construction."""
        runner = MagicMock()
        runner.model.module = MagicMock()
        # MagicMock auto-creates any attribute on access, which would make
        # hasattr(model, 'sample_kinds') True. Materialise the attribute and
        # then delete it so the hasattr check genuinely fails.
        runner.model.module.__setattr__('sample_kinds', None)
        del runner.model.module.sample_kinds
        dataloader = MagicMock()
        evaluator = [MagicMock()]
        fix_subnet_kinds = []
        with pytest.raises(ValueError):
            SubnetValLoop(
                runner,
                dataloader,
                evaluator,
                fix_subnet_kinds=fix_subnet_kinds,
                estimator_cfg=None)