Skip to content

Commit 3f8bfeb

Browse files
fix typehint for Solver.predict
1 parent be29e16 commit 3f8bfeb

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

ppsci/solver/solver.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
from typing import Any
2222
from typing import Dict
2323
from typing import Optional
24+
from typing import Union
2425

26+
import numpy as np
2527
import paddle
2628
import paddle.distributed as dist
2729
import visualdl as vdl
@@ -67,6 +69,9 @@ class Solver:
6769
amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0".
6870
pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None.
6971
checkpoint_path (Optional[str]): Checkpoint path. Defaults to None.
72+
compute_metric_by_batch (bool, optional): Whether to calculate metrics after each batch during evaluation. Defaults to False.
73+
eval_with_no_grad (bool, optional): Whether to set `stop_gradient=True` for every Tensor if no differentiation is
74+
involved during computation, generally to save GPU memory and accelerate computing. Defaults to False.
7075
7176
Examples:
7277
>>> import ppsci
@@ -120,6 +125,8 @@ def __init__(
120125
amp_level: Literal["O1", "O2", "O0"] = "O0",
121126
pretrained_model_path: Optional[str] = None,
122127
checkpoint_path: Optional[str] = None,
128+
compute_metric_by_batch: bool = False,
129+
eval_with_no_grad: bool = False,
123130
):
124131
# set model
125132
self.model = model
@@ -190,6 +197,11 @@ def __init__(
190197
if pretrained_model_path is not None:
191198
save_load.load_pretrain(self.model, pretrained_model_path, self.equation)
192199

200+
# whether to calculate metrics after each batch during evaluation
201+
self.compute_metric_by_batch = compute_metric_by_batch
202+
# whether to set `stop_gradient=True` for every Tensor if no differentiation is involved during computation
203+
self.eval_with_no_grad = eval_with_no_grad
204+
193205
# initialize a dict for tracking the best metric during training
194206
self.best_metric = {
195207
"metric": float("inf"),
@@ -291,6 +303,8 @@ def from_config(cfg: Dict[str, Any]) -> Solver:
291303
update_freq = cfg["Global"].get("update_freq", 1)
292304
pretrained_model_path = cfg["Global"].get("pretrained_model_path", None)
293305
checkpoint_path = cfg["Global"].get("checkpoint_path", None)
306+
compute_metric_by_batch = cfg["Global"].get("compute_metric_by_batch", False)
307+
eval_with_no_grad = cfg["Global"].get("eval_with_no_grad", False)
294308

295309
return Solver(
296310
model,
@@ -317,6 +331,8 @@ def from_config(cfg: Dict[str, Any]) -> Solver:
317331
amp_level,
318332
pretrained_model_path,
319333
checkpoint_path,
334+
compute_metric_by_batch,
335+
eval_with_no_grad,
320336
)
321337

322338
def train(self):
@@ -430,12 +446,14 @@ def visualize(self, epoch_id: int = 0):
430446

431447
@paddle.no_grad()
432448
def predict(
433-
self, input_dict: Dict[str, paddle.Tensor], batch_size: int = 64
449+
self,
450+
input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
451+
batch_size: int = 64,
434452
) -> Dict[str, paddle.Tensor]:
435453
"""Pure prediction using model.forward(...), support single device prediction yet.
436454
437455
Args:
438-
input_dict (Dict[str, paddle.Tensor]): Input data in dict.
456+
input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
439457
batch_size (int, optional): Predicting by batch size. Defaults to 64.
440458
441459
Returns:
@@ -516,3 +534,20 @@ def autocast_context_manager(self) -> contextlib.AbstractContextManager:
516534
)
517535

518536
return ctx_manager
537+
538+
def no_grad_context_manager(self) -> contextlib.AbstractContextManager:
    """Return a context manager controlling gradient computation for evaluation.

    When ``self.eval_with_no_grad`` is True, returns ``paddle.no_grad()`` so
    that no gradient bookkeeping is done inside the ``with`` block (saves GPU
    memory and speeds up evaluation). Otherwise returns a no-op context
    manager, so callers can always write
    ``with solver.no_grad_context_manager(): ...`` unconditionally.

    Returns:
        contextlib.AbstractContextManager: ``paddle.no_grad()`` if
            ``self.eval_with_no_grad`` is True, otherwise a null context
            manager.
    """
    if self.eval_with_no_grad:
        ctx_manager = paddle.no_grad()
    else:
        # contextlib.nullcontext was added in Python 3.7; an argument-less
        # contextlib.suppress() is an equivalent no-op manager on older versions
        ctx_manager = (
            contextlib.nullcontext()
            if sys.version_info >= (3, 7)
            else contextlib.suppress()
        )

    return ctx_manager

0 commit comments

Comments
 (0)