Skip to content

Commit 8846a0e

Browse files
author
liukai
committed
add tests
1 parent 92b40cf commit 8846a0e

File tree

4 files changed

+122
-0
lines changed

4 files changed

+122
-0
lines changed

tests/test_models/test_algorithms/test_prune_algorithm.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from mmrazor.models.algorithms.pruning.ite_prune_algorithm import (
1212
ItePruneAlgorithm, ItePruneConfigManager)
1313
from mmrazor.registry import MODELS
14+
from projects.group_fisher.modules.group_fisher_algorthm import \
15+
GroupFisherAlgorithm
16+
from projects.group_fisher.modules.group_fisher_ops import GroupFisherConv2d
1417
from ...utils.set_dist_env import SetDistEnv
1518

1619

@@ -262,3 +265,63 @@ def test_resume(self):
262265
print(algorithm2.mutator.current_choices)
263266
self.assertDictEqual(algorithm.mutator.current_choices,
264267
algorithm2.mutator.current_choices)
268+
269+
270+
class TestGroupFisherPruneAlgorithm(TestItePruneAlgorithm):
    """End-to-end check of ``GroupFisherAlgorithm``.

    Reuses the CIFAR-style fake data and epoch/iteration helpers from
    ``TestItePruneAlgorithm``. The test drives the algorithm through two
    epochs of forward passes with fabricated gradients and verifies that a
    delta-threshold checkpoint file is written, then cleans up every file
    it created.
    """

    def test_group_fisher_prune(self):
        """Run the pruning loop and verify checkpoint side effects."""
        data = self.fake_cifar_data()

        MUTATOR_CONFIG = dict(
            type='GroupFisherChannelMutator',
            parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
            channel_unit_cfg=dict(type='GroupFisherChannelUnit'))

        epoch = 2
        interval = 1
        delta = 'flops'

        # save_ckpt_delta_thr=[1.1] makes the very first delta ratio cross
        # the threshold, so a checkpoint is produced during the loop below.
        algorithm = GroupFisherAlgorithm(
            MODEL_CFG,
            pruning=True,
            mutator=MUTATOR_CONFIG,
            delta=delta,
            interval=interval,
            save_ckpt_delta_thr=[1.1]).to(DEVICE)
        mutator = algorithm.mutator

        # Paths the algorithm is expected to create next to this test file.
        ckpt_path = os.path.dirname(__file__) + f'/{delta}_0.99.pth'
        fake_cfg_path = os.path.dirname(__file__) + '/cfg.py'

        self.gen_fake_cfg(fake_cfg_path)
        self.assertTrue(os.path.exists(fake_cfg_path))

        # The algorithm reads its work_dir from the config string stored in
        # the global MessageHub, so publish the fake config there.
        message_hub = MessageHub.get_current_instance()
        cfg_str = open(fake_cfg_path).read()
        message_hub.update_info('cfg', cfg_str)

        for epoch_idx in range(epoch):
            for iter_idx in range(10):
                self._set_epoch_ite(epoch_idx, iter_idx, epoch)
                algorithm.forward(
                    data['inputs'], data['data_samples'], mode='loss')
                # Fisher information needs recorded gradients; fake them
                # since no backward pass is run here.
                self.gen_fake_grad(mutator)

        self.assertEqual(delta, algorithm.delta)
        self.assertEqual(interval, algorithm.interval)
        self.assertTrue(os.path.exists(ckpt_path))

        # Remove the generated files and confirm the cleanup succeeded.
        os.remove(ckpt_path)
        os.remove(fake_cfg_path)
        self.assertTrue(not os.path.exists(ckpt_path))
        self.assertTrue(not os.path.exists(fake_cfg_path))

    def gen_fake_grad(self, mutator):
        """Copy each recorded input onto ``recorded_grad`` so Fisher
        statistics can be computed without a real backward pass."""
        for unit in mutator.mutable_units:
            for channel in unit.input_related:
                module = channel.module
                if isinstance(module, GroupFisherConv2d):
                    module.recorded_grad = module.recorded_input

    def gen_fake_cfg(self, fake_cfg_path):
        """Append a one-line config declaring ``work_dir`` as this test's
        directory to the file at ``fake_cfg_path``."""
        with open(fake_cfg_path, 'a', encoding='utf-8') as cfg:
            cfg.write(f'work_dir = \'{os.path.dirname(__file__)}\'')
            cfg.write('\n')

tests/test_projects/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import unittest
3+
4+
import torch
5+
6+
from mmrazor.models.mutables import SimpleMutableChannel
7+
from mmrazor.models.mutators import ChannelMutator
8+
from projects.cores.expandable_ops.ops import ExpandLinear
9+
from projects.cores.expandable_ops.unit import (ExpandableUnit, expand_model,
10+
expand_static_model)
11+
from ...data.models import MultiConcatModel, SingleLineModel
12+
13+
14+
class TestExpand(unittest.TestCase):
    """Tests for the expandable-ops project: growing channel counts of a
    model (or a single layer) without changing its outputs."""

    def test_expand(self):
        """Expanding every mutable unit with zero-initialized new channels
        must leave the model's output numerically unchanged."""
        inputs = torch.rand([1, 3, 224, 224])
        model = MultiConcatModel()
        print(model)

        mutator = ChannelMutator[ExpandableUnit](
            channel_unit_cfg=ExpandableUnit)
        mutator.prepare_from_supernet(model)
        print(mutator.choice_template)
        print(model)
        out_before = model(inputs)

        for unit in mutator.mutable_units:
            unit.expand(10)
            print(unit.mutable_channel.mask.shape)
        # zero=True fills the newly added channels with zeros, which is what
        # keeps the forward result identical.
        expand_model(model, zero=True)
        print(model)
        out_after = model(inputs)
        self.assertTrue((out_before - out_after).abs().max() < 1e-3)

    def test_expand_static_model(self):
        """Expanding a plain (non-mutable) model to a channel divisor must
        also preserve its output."""
        inputs = torch.rand([1, 3, 224, 224])
        model = SingleLineModel()
        out_before = model(inputs)
        expand_static_model(model, divisor=4)
        out_after = model(inputs)
        print(out_before.reshape([-1])[:5])
        print(out_after.reshape([-1])[:5])
        self.assertTrue((out_before - out_after).abs().max() < 1e-3)

    def test_ExpandConv2d(self):
        # NOTE(review): despite the name, this exercises ExpandLinear, not a
        # Conv2d — consider renaming once downstream test selection is
        # confirmed not to depend on the current name.
        linear = ExpandLinear(3, 3)
        mutable_in = SimpleMutableChannel(3)
        mutable_out = SimpleMutableChannel(3)
        linear.register_mutable_attr('in_channels', mutable_in)
        linear.register_mutable_attr('out_channels', mutable_out)

        print(linear.weight)

        # Masks longer than the current channel count (5 > 3) describe the
        # expanded layout; expand(zero=True) materializes it.
        mutable_in.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
        mutable_out.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
        expanded = linear.expand(zero=True)
        print(expanded.weight)

0 commit comments

Comments
 (0)