Skip to content

Commit 3fe8b03

Browse files
refine type hint of train/eval/visu
1 parent 3ef1adb commit 3fe8b03

File tree

4 files changed

+19
-11
lines changed

4 files changed

+19
-11
lines changed

ppsci/solver/eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
import paddle.amp as amp
1919
import paddle.io as io
2020

21+
from ppsci import solver
2122
from ppsci.solver import printer
2223
from ppsci.utils import expression
2324
from ppsci.utils import misc
2425
from ppsci.utils import profiler
2526

2627

27-
def eval_func(solver, epoch_id, log_freq) -> float:
28+
def eval_func(solver: solver.Solver, epoch_id: int, log_freq: int) -> float:
2829
"""Evaluation program
2930
3031
Args:

ppsci/solver/solver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def train(self):
392392
if self.vdl_writer is not None:
393393
self.vdl_writer.close()
394394

395-
def eval(self, epoch_id=0):
395+
def eval(self, epoch_id: int = 0):
396396
"""Evaluation"""
397397
train_state = self.model.training
398398
if train_state:
@@ -412,7 +412,7 @@ def eval(self, epoch_id=0):
412412
self.model.train()
413413
return result
414414

415-
def visualize(self, epoch_id=0):
415+
def visualize(self, epoch_id: int = 0):
416416
"""Visualization"""
417417
train_state = self.model.training
418418
if train_state:
@@ -461,7 +461,9 @@ def predict(
461461
# prepare batch input dict
462462
for key in input_dict:
463463
if not paddle.is_tensor(input_dict[key]):
464-
batch_input_dict[key] = paddle.to_tensor(input_dict[key][st:ed])
464+
batch_input_dict[key] = paddle.to_tensor(
465+
input_dict[key][st:ed], paddle.get_default_dtype()
466+
)
465467
else:
466468
batch_input_dict[key] = input_dict[key][st:ed]
467469
batch_input_dict[key].stop_gradient = False

ppsci/solver/train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717
import paddle.amp as amp
1818

19+
from ppsci import solver
1920
from ppsci.solver import printer
2021
from ppsci.utils import expression
2122
from ppsci.utils import misc
2223
from ppsci.utils import profiler
2324

2425

25-
def train_epoch_func(solver, epoch_id, log_freq):
26+
def train_epoch_func(solver: solver.Solver, epoch_id: int, log_freq: int):
2627
"""Train program for one epoch
2728
2829
Args:
29-
solver (Solver): Main solver.
30+
solver (solver.Solver): Main solver.
3031
epoch_id (int): Epoch id.
3132
log_freq (int): Log training information every `log_freq` steps.
3233
"""
@@ -112,11 +113,11 @@ def train_epoch_func(solver, epoch_id, log_freq):
112113
batch_tic = time.perf_counter()
113114

114115

115-
def train_LBFGS_epoch_func(solver, epoch_id, log_freq):
116+
def train_LBFGS_epoch_func(solver: solver.Solver, epoch_id: int, log_freq: int):
116117
"""Train function for one epoch with L-BFGS optimizer.
117118
118119
Args:
119-
solver (Solver): Main solver.
120+
solver (solver.Solver): Main solver.
120121
epoch_id (int): Epoch id.
121122
log_freq (int): Log training information every `log_freq` steps.
122123
"""

ppsci/solver/visu.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@
1818
import paddle
1919
import paddle.amp as amp
2020

21+
from ppsci import solver
2122
from ppsci.utils import expression
2223
from ppsci.utils import misc
2324

2425

25-
def visualize_func(solver, epoch_id):
26+
@paddle.no_grad()
27+
def visualize_func(solver: solver.Solver, epoch_id: int):
2628
"""Visualization program
2729
2830
Args:
29-
solver (Solver): Main Solver.
31+
solver (solver.Solver): Main Solver.
3032
epoch_id (int): Epoch id.
3133
3234
Returns:
@@ -49,7 +51,9 @@ def visualize_func(solver, epoch_id):
4951
# prepare batch input dict
5052
for key in input_dict:
5153
if not paddle.is_tensor(input_dict[key]):
52-
batch_input_dict[key] = paddle.to_tensor(input_dict[key][st:ed])
54+
batch_input_dict[key] = paddle.to_tensor(
55+
input_dict[key][st:ed], paddle.get_default_dtype()
56+
)
5357
else:
5458
batch_input_dict[key] = input_dict[key][st:ed]
5559
batch_input_dict[key].stop_gradient = False

0 commit comments

Comments
 (0)