@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
21
21
evaluator (Evaluator or dict or list): Used for computing metrics.
22
22
fp16 (bool): Whether to enable fp16 validation. Defaults to
23
23
False.
24
- evaluate_fixed_subnet (bool ): Whether to evaluate a fixed subnet only
25
- or not. Defaults to False.
24
+ fix_subnet_kind (str ): fix subnet kinds when evaluate, this would be
25
+ `sample_kinds` if not specified
26
26
calibrate_sample_num (int): The number of images to compute the true
27
27
average of per-batch mean/variance instead of the running average.
28
28
Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
36
36
dataloader : Union [DataLoader , Dict ],
37
37
evaluator : Union [Evaluator , Dict , List ],
38
38
fp16 : bool = False ,
39
- evaluate_fixed_subnet : bool = False ,
39
+ fix_subnet_kinds : List [ str ] = [] ,
40
40
calibrate_sample_num : int = 4096 ,
41
41
estimator_cfg : Optional [Dict ] = dict (type = 'mmrazor.ResourceEstimator' )
42
42
) -> None :
@@ -48,9 +48,18 @@ def __init__(
48
48
model = self .runner .model
49
49
50
50
self .model = model
51
- self .evaluate_fixed_subnet = evaluate_fixed_subnet
51
+ if len (fix_subnet_kinds ) == 0 and not hasattr (self .model ,
52
+ 'sample_kinds' ):
53
+ raise ValueError (
54
+ 'neither fix_subnet_kinds nor self.model.sample_kinds exists' )
55
+
56
+ self .evaluate_kinds = fix_subnet_kinds if len (
57
+ fix_subnet_kinds ) > 0 else getattr (self .model , 'sample_kinds' )
58
+
52
59
self .calibrate_sample_num = calibrate_sample_num
53
- self .estimator = TASK_UTILS .build (estimator_cfg )
60
+ self .estimator = None
61
+ if estimator_cfg :
62
+ self .estimator = TASK_UTILS .build (estimator_cfg )
54
63
55
64
def run (self ):
56
65
"""Launch validation."""
@@ -59,24 +68,19 @@ def run(self):
59
68
60
69
all_metrics = dict ()
61
70
62
- if self .evaluate_fixed_subnet :
71
+ for kind in self .evaluate_kinds :
72
+ if kind == 'max' :
73
+ self .model .mutator .set_max_choices ()
74
+ elif kind == 'min' :
75
+ self .model .mutator .set_min_choices ()
76
+ elif 'random' in kind :
77
+ self .model .mutator .set_choices (
78
+ self .model .mutator .sample_choices ())
79
+ else :
80
+ raise NotImplementedError (f'Unsupported Subnet { kind } ' )
81
+
63
82
metrics = self ._evaluate_once ()
64
- all_metrics .update (add_prefix (metrics , 'fix_subnet' ))
65
- elif hasattr (self .model , 'sample_kinds' ):
66
- for kind in self .model .sample_kinds :
67
- if kind == 'max' :
68
- self .model .mutator .set_max_choices ()
69
- metrics = self ._evaluate_once ()
70
- all_metrics .update (add_prefix (metrics , 'max_subnet' ))
71
- elif kind == 'min' :
72
- self .model .mutator .set_min_choices ()
73
- metrics = self ._evaluate_once ()
74
- all_metrics .update (add_prefix (metrics , 'min_subnet' ))
75
- elif 'random' in kind :
76
- self .model .mutator .set_choices (
77
- self .model .mutator .sample_choices ())
78
- metrics = self ._evaluate_once ()
79
- all_metrics .update (add_prefix (metrics , f'{ kind } _subnet' ))
83
+ all_metrics .update (add_prefix (metrics , f'{ kind } _subnet' ))
80
84
81
85
self .runner .call_hook ('after_val_epoch' , metrics = all_metrics )
82
86
self .runner .call_hook ('after_val' )
@@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
90
94
self .run_iter (idx , data_batch )
91
95
92
96
metrics = self .evaluator .evaluate (len (self .dataloader .dataset ))
93
- resource_metrics = self .estimator .estimate (self .model )
94
- metrics .update (resource_metrics )
97
+ if self .estimator :
98
+ resource_metrics = self .estimator .estimate (self .model )
99
+ metrics .update (resource_metrics )
95
100
96
101
return metrics
0 commit comments