21
21
from typing import Any
22
22
from typing import Dict
23
23
from typing import Optional
24
+ from typing import Union
24
25
26
+ import numpy as np
25
27
import paddle
26
28
import paddle .distributed as dist
27
29
import visualdl as vdl
@@ -67,6 +69,9 @@ class Solver:
67
69
amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0".
68
70
pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None.
69
71
checkpoint_path (Optional[str]): Checkpoint path. Defaults to None.
72
+ compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluate. Defaults to False.
73
+ eval_with_no_grad (bool, optional): Whether set `stop_gradient=True` for every Tensor if no differentiation
74
+ involved during computation, generally for save GPU memory and accelerate computing. Defaults to False.
70
75
71
76
Examples:
72
77
>>> import ppsci
@@ -120,6 +125,8 @@ def __init__(
120
125
amp_level : Literal ["O1" , "O2" , "O0" ] = "O0" ,
121
126
pretrained_model_path : Optional [str ] = None ,
122
127
checkpoint_path : Optional [str ] = None ,
128
+ compute_metric_by_batch : bool = False ,
129
+ eval_with_no_grad : bool = False ,
123
130
):
124
131
# set model
125
132
self .model = model
@@ -190,6 +197,11 @@ def __init__(
190
197
if pretrained_model_path is not None :
191
198
save_load .load_pretrain (self .model , pretrained_model_path , self .equation )
192
199
200
+ # whether calculate metrics after each batch during evaluate
201
+ self .compute_metric_by_batch = compute_metric_by_batch
202
+ # whether set `stop_gradient=True` for every Tensor if no differentiation involved during computation
203
+ self .eval_with_no_grad = eval_with_no_grad
204
+
193
205
# initialize an dict for tracking best metric during training
194
206
self .best_metric = {
195
207
"metric" : float ("inf" ),
@@ -291,6 +303,8 @@ def from_config(cfg: Dict[str, Any]) -> Solver:
291
303
update_freq = cfg ["Global" ].get ("update_freq" , 1 )
292
304
pretrained_model_path = cfg ["Global" ].get ("pretrained_model_path" , None )
293
305
checkpoint_path = cfg ["Global" ].get ("checkpoint_path" , None )
306
+ compute_metric_by_batch = cfg ["Global" ].get ("compute_metric_by_batch" , False )
307
+ eval_with_no_grad = cfg ["Global" ].get ("eval_with_no_grad" , False )
294
308
295
309
return Solver (
296
310
model ,
@@ -317,6 +331,8 @@ def from_config(cfg: Dict[str, Any]) -> Solver:
317
331
amp_level ,
318
332
pretrained_model_path ,
319
333
checkpoint_path ,
334
+ compute_metric_by_batch ,
335
+ eval_with_no_grad ,
320
336
)
321
337
322
338
def train (self ):
@@ -430,12 +446,14 @@ def visualize(self, epoch_id: int = 0):
430
446
431
447
@paddle .no_grad ()
432
448
def predict (
433
- self , input_dict : Dict [str , paddle .Tensor ], batch_size : int = 64
449
+ self ,
450
+ input_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
451
+ batch_size : int = 64 ,
434
452
) -> Dict [str , paddle .Tensor ]:
435
453
"""Pure prediction using model.forward(...), support single device prediction yet.
436
454
437
455
Args:
438
- input_dict (Dict[str, paddle.Tensor]): Input data in dict.
456
+ input_dict (Dict[str, Union[np.ndarray, paddle.Tensor] ]): Input data in dict.
439
457
batch_size (int, optional): Predicting by batch size. Defaults to 64.
440
458
441
459
Returns:
@@ -516,3 +534,20 @@ def autocast_context_manager(self) -> contextlib.AbstractContextManager:
516
534
)
517
535
518
536
return ctx_manager
537
+
538
+ def no_grad_context_manager (self ) -> contextlib .AbstractContextManager :
539
+ """No grad manager.
540
+
541
+ Returns:
542
+ Union[contextlib.AbstractContextManager]: Context manager.
543
+ """
544
+ if self .eval_with_no_grad :
545
+ ctx_manager = paddle .no_grad ()
546
+ else :
547
+ ctx_manager = (
548
+ contextlib .nullcontext ()
549
+ if sys .version_info >= (3 , 7 )
550
+ else contextlib .suppress ()
551
+ )
552
+
553
+ return ctx_manager
0 commit comments