13
13
# limitations under the License.
14
14
15
15
import importlib .util
16
+ import traceback
16
17
from typing import Dict
17
18
from typing import Tuple
18
19
from typing import Union
19
20
21
+ import paddle
22
+
20
23
from ppsci .utils import logger
21
24
22
25
@@ -26,18 +29,7 @@ def run_check() -> None:
26
29
27
30
Examples:
28
31
>>> import ppsci
29
- >>> ppsci.utils.run_check()
30
- Runing test code [1/2] [1/5]
31
- Runing test code [1/2] [2/5]
32
- Runing test code [1/2] [3/5]
33
- Runing test code [1/2] [4/5]
34
- Runing test code [1/2] [5/5]
35
- Runing test code [2/2] [1/5]
36
- Runing test code [2/2] [2/5]
37
- Runing test code [2/2] [3/5]
38
- Runing test code [2/2] [4/5]
39
- Runing test code [2/2] [5/5]
40
- PaddleScience is installed successfully.✨ 🍰 ✨
32
+ >>> ppsci.utils.run_check() # doctest: +SKIP
41
33
"""
42
34
43
35
# test demo code below.
@@ -46,24 +38,26 @@ def run_check() -> None:
46
38
import ppsci
47
39
48
40
try :
49
- model = ppsci .arch .MLP (("x" , "y" ), ("u" , "v" , "p" ), 9 , 50 , "tanh" , False , False )
41
+ ppsci .utils .set_random_seed (42 )
42
+ ppsci .utils .logger .init_logger ()
43
+ model = ppsci .arch .MLP (("x" , "y" ), ("u" , "v" , "p" ), 3 , 16 , "tanh" )
50
44
51
45
equation = {"NavierStokes" : ppsci .equation .NavierStokes (0.01 , 1.0 , 2 , False )}
52
46
53
47
geom = {"rect" : ppsci .geometry .Rectangle ((- 0.05 , - 0.05 ), (0.05 , 0.05 ))}
54
48
55
- iters_per_epoch = 5
49
+ ITERS_PER_EPOCH = 5
56
50
train_dataloader_cfg = {
57
51
"dataset" : "IterableNamedArrayDataset" ,
58
- "iters_per_epoch" : iters_per_epoch ,
52
+ "iters_per_epoch" : ITERS_PER_EPOCH ,
59
53
}
60
54
61
- npoint_pde = 8 ** 2
55
+ NPOINT_PDE = 8 ** 2
62
56
pde_constraint = ppsci .constraint .InteriorConstraint (
63
57
equation ["NavierStokes" ].equations ,
64
58
{"continuity" : 0 , "momentum_x" : 0 , "momentum_y" : 0 },
65
59
geom ["rect" ],
66
- {** train_dataloader_cfg , "batch_size" : npoint_pde },
60
+ {** train_dataloader_cfg , "batch_size" : NPOINT_PDE },
67
61
ppsci .loss .MSELoss ("sum" ),
68
62
evenly = True ,
69
63
weight_dict = {
@@ -73,31 +67,43 @@ def run_check() -> None:
73
67
},
74
68
name = "EQ" ,
75
69
)
70
+ constraint = {pde_constraint .name : pde_constraint }
71
+
72
+ residual_validator = ppsci .validate .GeometryValidator (
73
+ equation ["NavierStokes" ].equations ,
74
+ {"continuity" : 0 , "momentum_x" : 0 , "momentum_y" : 0 },
75
+ geom ["rect" ],
76
+ {
77
+ "dataset" : "NamedArrayDataset" ,
78
+ "total_size" : 8 ** 2 ,
79
+ "batch_size" : 32 ,
80
+ "sampler" : {"name" : "BatchSampler" },
81
+ },
82
+ ppsci .loss .MSELoss ("sum" ),
83
+ evenly = True ,
84
+ metric = {"MSE" : ppsci .metric .MSE (False )},
85
+ name = "Residual" ,
86
+ )
87
+ validator = {residual_validator .name : residual_validator }
76
88
77
- epochs = 2
89
+ EPOCHS = 2
78
90
optimizer = ppsci .optimizer .Adam (0.001 )((model ,))
79
- for _epoch in range (1 , epochs + 1 ):
80
- for _iter_id in range (1 , iters_per_epoch + 1 ):
81
- input_dict , label_dict , weight_dict = next (pde_constraint .data_iter )
82
- for v in input_dict .values ():
83
- v .stop_gradient = False
84
- evaluator = ppsci .utils .ExpressionSolver (
85
- pde_constraint .input_keys , pde_constraint .output_keys , model
86
- )
87
- for output_name , output_formula in pde_constraint .output_expr .items ():
88
- if output_name in label_dict :
89
- evaluator .add_target_expr (output_formula , output_name )
90
-
91
- output_dict = evaluator (input_dict )
92
- loss = pde_constraint .loss (output_dict , label_dict , weight_dict )
93
- loss .backward ()
94
- optimizer .step ()
95
- optimizer .clear_grad ()
96
- print (
97
- f"Runing test code [{ _epoch } /{ epochs } ]"
98
- f" [{ _iter_id } /{ iters_per_epoch } ]"
99
- )
91
+ solver = ppsci .solver .Solver (
92
+ model ,
93
+ constraint ,
94
+ None ,
95
+ optimizer ,
96
+ None ,
97
+ EPOCHS ,
98
+ ITERS_PER_EPOCH ,
99
+ device = paddle .device .get_device (),
100
+ equation = equation ,
101
+ validator = validator ,
102
+ )
103
+ solver .train ()
104
+ solver .eval (EPOCHS )
100
105
except Exception as e :
106
+ traceback .print_exc ()
101
107
logging .warning (
102
108
f"PaddleScience meets some problem with \n { repr (e )} \n please check whether "
103
109
"Paddle's version and PaddleScience's version are both correct."
0 commit comments