@@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
21
21
evaluator (Evaluator or dict or list): Used for computing metrics.
22
22
fp16 (bool): Whether to enable fp16 validation. Defaults to
23
23
False.
24
- evaluate_fixed_subnet (bool ): Whether to evaluate a fixed subnet only
25
- or not. Defaults to False.
24
+ fix_subnet_kind (str ): fix subnet kinds when evaluate, this would be
25
+ `sample_kinds` if not specified
26
26
calibrate_sample_num (int): The number of images to compute the true
27
27
average of per-batch mean/variance instead of the running average.
28
28
Defaults to 4096.
@@ -36,7 +36,7 @@ def __init__(
36
36
dataloader : Union [DataLoader , Dict ],
37
37
evaluator : Union [Evaluator , Dict , List ],
38
38
fp16 : bool = False ,
39
- evaluate_fixed_subnet : bool = False ,
39
+ fix_subnet_kinds : List [ str ] = [] ,
40
40
calibrate_sample_num : int = 4096 ,
41
41
estimator_cfg : Optional [Dict ] = dict (type = 'mmrazor.ResourceEstimator' )
42
42
) -> None :
@@ -48,9 +48,13 @@ def __init__(
48
48
model = self .runner .model
49
49
50
50
self .model = model
51
- self .evaluate_fixed_subnet = evaluate_fixed_subnet
51
+ self .evaluate_kinds = fix_subnet_kinds if len (
52
+ fix_subnet_kinds ) > 0 else getattr (self .model , 'sample_kinds' )
53
+
52
54
self .calibrate_sample_num = calibrate_sample_num
53
- self .estimator = TASK_UTILS .build (estimator_cfg )
55
+ self .estimator = None
56
+ if estimator_cfg :
57
+ self .estimator = TASK_UTILS .build (estimator_cfg )
54
58
55
59
def run (self ):
56
60
"""Launch validation."""
@@ -59,24 +63,19 @@ def run(self):
59
63
60
64
all_metrics = dict ()
61
65
62
- if self .evaluate_fixed_subnet :
66
+ for kind in self .evaluate_kinds :
67
+ if kind == 'max' :
68
+ self .model .mutator .set_max_choices ()
69
+ elif kind == 'min' :
70
+ self .model .mutator .set_min_choices ()
71
+ elif 'random' in kind :
72
+ self .model .mutator .set_choices (
73
+ self .model .mutator .sample_choices ())
74
+ else :
75
+ raise NotImplementedError (f'{ kind } ' )
76
+
63
77
metrics = self ._evaluate_once ()
64
- all_metrics .update (add_prefix (metrics , 'fix_subnet' ))
65
- elif hasattr (self .model , 'sample_kinds' ):
66
- for kind in self .model .sample_kinds :
67
- if kind == 'max' :
68
- self .model .set_max_subnet ()
69
- metrics = self ._evaluate_once ()
70
- all_metrics .update (add_prefix (metrics , 'max_subnet' ))
71
- elif kind == 'min' :
72
- self .model .set_min_subnet ()
73
- metrics = self ._evaluate_once ()
74
- all_metrics .update (add_prefix (metrics , 'min_subnet' ))
75
- elif 'random' in kind :
76
- self .model .mutator .set_choices (
77
- self .model .mutator .sample_choices ())
78
- metrics = self ._evaluate_once ()
79
- all_metrics .update (add_prefix (metrics , f'{ kind } _subnet' ))
78
+ all_metrics .update (add_prefix (metrics , f'{ kind } _subnet' ))
80
79
81
80
self .runner .call_hook ('after_val_epoch' , metrics = all_metrics )
82
81
self .runner .call_hook ('after_val' )
@@ -90,7 +89,8 @@ def _evaluate_once(self) -> Dict:
90
89
self .run_iter (idx , data_batch )
91
90
92
91
metrics = self .evaluator .evaluate (len (self .dataloader .dataset ))
93
- resource_metrics = self .estimator .estimate (self .model )
94
- metrics .update (resource_metrics )
92
+ if self .estimator :
93
+ resource_metrics = self .estimator .estimate (self .model )
94
+ metrics .update (resource_metrics )
95
95
96
96
return metrics
0 commit comments