24
24
from ppsci .utils import logger
25
25
26
26
27
+ def plot (input_data , N_EVAL , pinn_output , fdm_output , cfg ):
28
+ x = input_data ["x" ].reshape (N_EVAL , N_EVAL )
29
+ y = input_data ["y" ].reshape (N_EVAL , N_EVAL )
30
+
31
+ plt .subplot (2 , 1 , 1 )
32
+ plt .pcolormesh (x , y , pinn_output * 75.0 , cmap = "magma" )
33
+ plt .colorbar ()
34
+ plt .title ("PINN" )
35
+ plt .xlabel ("x" )
36
+ plt .ylabel ("y" )
37
+ plt .tight_layout ()
38
+ plt .axis ("square" )
39
+
40
+ plt .subplot (2 , 1 , 2 )
41
+ plt .pcolormesh (x , y , fdm_output , cmap = "magma" )
42
+ plt .colorbar ()
43
+ plt .xlabel ("x" )
44
+ plt .ylabel ("y" )
45
+ plt .title ("FDM" )
46
+ plt .tight_layout ()
47
+ plt .axis ("square" )
48
+ plt .savefig (osp .join (cfg .output_dir , "pinn_fdm_comparison.png" ))
49
+ plt .close ()
50
+
51
+ frames_val = np .array ([- 0.75 , - 0.5 , - 0.25 , 0.0 , + 0.25 , + 0.5 , + 0.75 ])
52
+ frames = [* map (int , (frames_val + 1 ) / 2 * (N_EVAL - 1 ))]
53
+ height = 3
54
+ plt .figure ("" , figsize = (len (frames ) * height , 2 * height ))
55
+
56
+ for i , var_index in enumerate (frames ):
57
+ plt .subplot (2 , len (frames ), i + 1 )
58
+ plt .title (f"y = { frames_val [i ]:.2f} " )
59
+ plt .plot (
60
+ x [:, var_index ],
61
+ pinn_output [:, var_index ] * 75.0 ,
62
+ "r--" ,
63
+ lw = 4.0 ,
64
+ label = "pinn" ,
65
+ )
66
+ plt .plot (x [:, var_index ], fdm_output [:, var_index ], "b" , lw = 2.0 , label = "FDM" )
67
+ plt .ylim (0.0 , 100.0 )
68
+ plt .xlim (- 1.0 , + 1.0 )
69
+ plt .xlabel ("x" )
70
+ plt .ylabel ("T" )
71
+ plt .tight_layout ()
72
+ plt .legend ()
73
+
74
+ for i , var_index in enumerate (frames ):
75
+ plt .subplot (2 , len (frames ), len (frames ) + i + 1 )
76
+ plt .title (f"x = { frames_val [i ]:.2f} " )
77
+ plt .plot (
78
+ y [var_index , :],
79
+ pinn_output [var_index , :] * 75.0 ,
80
+ "r--" ,
81
+ lw = 4.0 ,
82
+ label = "pinn" ,
83
+ )
84
+ plt .plot (y [var_index , :], fdm_output [var_index , :], "b" , lw = 2.0 , label = "FDM" )
85
+ plt .ylim (0.0 , 100.0 )
86
+ plt .xlim (- 1.0 , + 1.0 )
87
+ plt .xlabel ("y" )
88
+ plt .ylabel ("T" )
89
+ plt .tight_layout ()
90
+ plt .legend ()
91
+
92
+ plt .savefig (osp .join (cfg .output_dir , "profiles.png" ))
93
+
94
+
27
95
def train (cfg : DictConfig ):
28
96
# set random seed for reproducibility
29
97
ppsci .utils .misc .set_random_seed (cfg .seed )
@@ -141,72 +209,7 @@ def train(cfg: DictConfig):
141
209
fdm_output = fdm .solve (N_EVAL , 1 ).T
142
210
mse_loss = np .mean (np .square (pinn_output - (fdm_output / 75.0 )))
143
211
logger .info (f"The norm MSE loss between the FDM and PINN is { mse_loss } " )
144
-
145
- x = input_data ["x" ].reshape (N_EVAL , N_EVAL )
146
- y = input_data ["y" ].reshape (N_EVAL , N_EVAL )
147
-
148
- plt .subplot (2 , 1 , 1 )
149
- plt .pcolormesh (x , y , pinn_output * 75.0 , cmap = "magma" )
150
- plt .colorbar ()
151
- plt .title ("PINN" )
152
- plt .xlabel ("x" )
153
- plt .ylabel ("y" )
154
- plt .tight_layout ()
155
- plt .axis ("square" )
156
-
157
- plt .subplot (2 , 1 , 2 )
158
- plt .pcolormesh (x , y , fdm_output , cmap = "magma" )
159
- plt .colorbar ()
160
- plt .xlabel ("x" )
161
- plt .ylabel ("y" )
162
- plt .title ("FDM" )
163
- plt .tight_layout ()
164
- plt .axis ("square" )
165
- plt .savefig (osp .join (cfg .output_dir , "pinn_fdm_comparison.png" ))
166
- plt .close ()
167
-
168
- frames_val = np .array ([- 0.75 , - 0.5 , - 0.25 , 0.0 , + 0.25 , + 0.5 , + 0.75 ])
169
- frames = [* map (int , (frames_val + 1 ) / 2 * (N_EVAL - 1 ))]
170
- height = 3
171
- plt .figure ("" , figsize = (len (frames ) * height , 2 * height ))
172
-
173
- for i , var_index in enumerate (frames ):
174
- plt .subplot (2 , len (frames ), i + 1 )
175
- plt .title (f"y = { frames_val [i ]:.2f} " )
176
- plt .plot (
177
- x [:, var_index ],
178
- pinn_output [:, var_index ] * 75.0 ,
179
- "r--" ,
180
- lw = 4.0 ,
181
- label = "pinn" ,
182
- )
183
- plt .plot (x [:, var_index ], fdm_output [:, var_index ], "b" , lw = 2.0 , label = "FDM" )
184
- plt .ylim (0.0 , 100.0 )
185
- plt .xlim (- 1.0 , + 1.0 )
186
- plt .xlabel ("x" )
187
- plt .ylabel ("T" )
188
- plt .tight_layout ()
189
- plt .legend ()
190
-
191
- for i , var_index in enumerate (frames ):
192
- plt .subplot (2 , len (frames ), len (frames ) + i + 1 )
193
- plt .title (f"x = { frames_val [i ]:.2f} " )
194
- plt .plot (
195
- y [var_index , :],
196
- pinn_output [var_index , :] * 75.0 ,
197
- "r--" ,
198
- lw = 4.0 ,
199
- label = "pinn" ,
200
- )
201
- plt .plot (y [var_index , :], fdm_output [var_index , :], "b" , lw = 2.0 , label = "FDM" )
202
- plt .ylim (0.0 , 100.0 )
203
- plt .xlim (- 1.0 , + 1.0 )
204
- plt .xlabel ("y" )
205
- plt .ylabel ("T" )
206
- plt .tight_layout ()
207
- plt .legend ()
208
-
209
- plt .savefig (osp .join (cfg .output_dir , "profiles.png" ))
212
+ plot (input_data , N_EVAL , pinn_output , fdm_output , cfg )
210
213
211
214
212
215
def evaluate (cfg : DictConfig ):
@@ -239,72 +242,49 @@ def evaluate(cfg: DictConfig):
239
242
fdm_output = fdm .solve (N_EVAL , 1 ).T
240
243
mse_loss = np .mean (np .square (pinn_output - (fdm_output / 75.0 )))
241
244
logger .info (f"The norm MSE loss between the FDM and PINN is { mse_loss :.5e} " )
245
+ plot (input_data , N_EVAL , pinn_output , fdm_output , cfg )
242
246
243
- x = input_data ["x" ].reshape (N_EVAL , N_EVAL )
244
- y = input_data ["y" ].reshape (N_EVAL , N_EVAL )
245
247
246
- plt .subplot (2 , 1 , 1 )
247
- plt .pcolormesh (x , y , pinn_output * 75.0 , cmap = "magma" )
248
- plt .colorbar ()
249
- plt .title ("PINN" )
250
- plt .xlabel ("x" )
251
- plt .ylabel ("y" )
252
- plt .tight_layout ()
253
- plt .axis ("square" )
248
+ def export (cfg : DictConfig ):
249
+ # set model
250
+ model = ppsci .arch .MLP (** cfg .MODEL )
254
251
255
- plt .subplot (2 , 1 , 2 )
256
- plt .pcolormesh (x , y , fdm_output , cmap = "magma" )
257
- plt .colorbar ()
258
- plt .xlabel ("x" )
259
- plt .ylabel ("y" )
260
- plt .title ("FDM" )
261
- plt .tight_layout ()
262
- plt .axis ("square" )
263
- plt .savefig (osp .join (cfg .output_dir , "pinn_fdm_comparison.png" ))
264
- plt .close ()
252
+ # initialize solver
253
+ solver = ppsci .solver .Solver (
254
+ model ,
255
+ cfg = cfg ,
256
+ )
257
+ # export model
258
+ from paddle .static import InputSpec
265
259
266
- frames_val = np . array ([ - 0.75 , - 0.5 , - 0.25 , 0.0 , + 0.25 , + 0.5 , + 0.75 ])
267
- frames = [ * map ( int , ( frames_val + 1 ) / 2 * ( N_EVAL - 1 ))]
268
- height = 3
269
- plt . figure ( "" , figsize = ( len ( frames ) * height , 2 * height ) )
260
+ input_spec = [
261
+ { key : InputSpec ([ None , 1 ], "float32" , name = key ) for key in model . input_keys },
262
+ ]
263
+ solver . export ( input_spec , cfg . INFER . export_path )
270
264
271
- for i , var_index in enumerate (frames ):
272
- plt .subplot (2 , len (frames ), i + 1 )
273
- plt .title (f"y = { frames_val [i ]:.2f} " )
274
- plt .plot (
275
- x [:, var_index ],
276
- pinn_output [:, var_index ] * 75.0 ,
277
- "r--" ,
278
- lw = 4.0 ,
279
- label = "pinn" ,
280
- )
281
- plt .plot (x [:, var_index ], fdm_output [:, var_index ], "b" , lw = 2.0 , label = "FDM" )
282
- plt .ylim (0.0 , 100.0 )
283
- plt .xlim (- 1.0 , + 1.0 )
284
- plt .xlabel ("x" )
285
- plt .ylabel ("T" )
286
- plt .tight_layout ()
287
- plt .legend ()
288
265
289
- for i , var_index in enumerate (frames ):
290
- plt .subplot (2 , len (frames ), len (frames ) + i + 1 )
291
- plt .title (f"x = { frames_val [i ]:.2f} " )
292
- plt .plot (
293
- y [var_index , :],
294
- pinn_output [var_index , :] * 75.0 ,
295
- "r--" ,
296
- lw = 4.0 ,
297
- label = "pinn" ,
298
- )
299
- plt .plot (y [var_index , :], fdm_output [var_index , :], "b" , lw = 2.0 , label = "FDM" )
300
- plt .ylim (0.0 , 100.0 )
301
- plt .xlim (- 1.0 , + 1.0 )
302
- plt .xlabel ("y" )
303
- plt .ylabel ("T" )
304
- plt .tight_layout ()
305
- plt .legend ()
266
+ def inference (cfg : DictConfig ):
267
+ from deploy .python_infer import pinn_predictor
306
268
307
- plt .savefig (osp .join (cfg .output_dir , "profiles.png" ))
269
+ predictor = pinn_predictor .PINNPredictor (cfg )
270
+ # set geometry
271
+ geom = {"rect" : ppsci .geometry .Rectangle ((- 1.0 , - 1.0 ), (1.0 , 1.0 ))}
272
+ # begin eval
273
+ N_EVAL = 100
274
+ input_data = geom ["rect" ].sample_interior (N_EVAL ** 2 , evenly = True )
275
+ output_data = predictor .predict (
276
+ {key : input_data [key ] for key in cfg .MODEL .input_keys }, cfg .INFER .batch_size
277
+ )
278
+
279
+ # mapping data to cfg.INFER.output_keys
280
+ output_data = {
281
+ store_key : output_data [infer_key ]
282
+ for store_key , infer_key in zip (cfg .MODEL .output_keys , output_data .keys ())
283
+ }["u" ].reshape (N_EVAL , N_EVAL )
284
+ fdm_output = fdm .solve (N_EVAL , 1 ).T
285
+ mse_loss = np .mean (np .square (output_data - (fdm_output / 75.0 )))
286
+ logger .info (f"The norm MSE loss between the FDM and PINN is { mse_loss :.5e} " )
287
+ plot (input_data , N_EVAL , output_data , fdm_output , cfg )
308
288
309
289
310
290
@hydra .main (version_base = None , config_path = "./conf" , config_name = "heat_pinn.yaml" )
@@ -313,8 +293,14 @@ def main(cfg: DictConfig):
313
293
train (cfg )
314
294
elif cfg .mode == "eval" :
315
295
evaluate (cfg )
296
+ elif cfg .mode == "export" :
297
+ export (cfg )
298
+ elif cfg .mode == "infer" :
299
+ inference (cfg )
316
300
else :
317
- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
301
+ raise ValueError (
302
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
303
+ )
318
304
319
305
320
306
if __name__ == "__main__" :
0 commit comments