Skip to content

Commit e54915d

Browse files
1want2sleepUnityLikerHydrogenSulfate
authored
【PPSCI Export&Infer No.31】heat_pinn (#926)
* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix * ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix * fix api docs in the timedomain * fix api docs of timedomain * fix api docs of timedomain * ppsci api docs fixed * ppsci api docs fixed * ppsci api docs fixed * add export and infer for bracket * updata bracket doc * solve conflict according to the branch named develop * Update examples/bracket/conf/bracket.yaml * Update examples/bracket/conf/bracket.yaml * Update examples/bracket/conf/bracket.yaml * add export&inference for bracket * add export and infer for heat_pinn * add export and infer for heat_pinn * Update examples/heat_pinn/heat_pinn.py * Update examples/heat_pinn/heat_pinn.py * Update examples/heat_pinn/conf/heat_pinn.yaml --------- Co-authored-by: krp <2934631798@qq.com> Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent 5edd8f2 commit e54915d

File tree

3 files changed

+143
-127
lines changed

3 files changed

+143
-127
lines changed

docs/zh/examples/heat_pinn.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
python heat_pinn.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams
1313
```
1414

15+
=== "模型导出命令"
16+
17+
``` sh
18+
python heat_pinn.py mode=export
19+
```
20+
21+
=== "模型推理命令"
22+
23+
``` sh
24+
python heat_pinn.py mode=infer
25+
```
26+
1527
| 预训练模型 | 指标 |
1628
|:--| :--|
1729
| [heat_pinn_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams) | norm MSE loss between the FDM and PINN is 1.30174e-03 |

examples/heat_pinn/conf/heat_pinn.yaml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
hydra:
22
run:
33
# dynamic output directory according to running time and override name
4-
dir: outputs_bracket/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
4+
dir: outputs_heat_pinn/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
55
job:
66
name: ${mode} # name of logfile
77
chdir: false # keep current working directory unchanged
@@ -50,3 +50,21 @@ TRAIN:
5050
# evaluation settings
5151
EVAL:
5252
pretrained_model_path: null
53+
54+
# inference settings
55+
INFER:
56+
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams"
57+
export_path: ./inference/heat_pinn
58+
pdmodel_path: ${INFER.export_path}.pdmodel
59+
pdiparams_path: ${INFER.export_path}.pdiparams
60+
device: gpu
61+
engine: native
62+
precision: fp32
63+
onnx_path: ${INFER.export_path}.onnx
64+
ir_optim: true
65+
min_subgraph_size: 10
66+
gpu_mem: 2000
67+
gpu_id: 0
68+
max_batch_size: 128
69+
num_cpu_threads: 4
70+
batch_size: 128

examples/heat_pinn/heat_pinn.py

Lines changed: 112 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,74 @@
2424
from ppsci.utils import logger
2525

2626

27+
def plot(input_data, N_EVAL, pinn_output, fdm_output, cfg):
28+
x = input_data["x"].reshape(N_EVAL, N_EVAL)
29+
y = input_data["y"].reshape(N_EVAL, N_EVAL)
30+
31+
plt.subplot(2, 1, 1)
32+
plt.pcolormesh(x, y, pinn_output * 75.0, cmap="magma")
33+
plt.colorbar()
34+
plt.title("PINN")
35+
plt.xlabel("x")
36+
plt.ylabel("y")
37+
plt.tight_layout()
38+
plt.axis("square")
39+
40+
plt.subplot(2, 1, 2)
41+
plt.pcolormesh(x, y, fdm_output, cmap="magma")
42+
plt.colorbar()
43+
plt.xlabel("x")
44+
plt.ylabel("y")
45+
plt.title("FDM")
46+
plt.tight_layout()
47+
plt.axis("square")
48+
plt.savefig(osp.join(cfg.output_dir, "pinn_fdm_comparison.png"))
49+
plt.close()
50+
51+
frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
52+
frames = [*map(int, (frames_val + 1) / 2 * (N_EVAL - 1))]
53+
height = 3
54+
plt.figure("", figsize=(len(frames) * height, 2 * height))
55+
56+
for i, var_index in enumerate(frames):
57+
plt.subplot(2, len(frames), i + 1)
58+
plt.title(f"y = {frames_val[i]:.2f}")
59+
plt.plot(
60+
x[:, var_index],
61+
pinn_output[:, var_index] * 75.0,
62+
"r--",
63+
lw=4.0,
64+
label="pinn",
65+
)
66+
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
67+
plt.ylim(0.0, 100.0)
68+
plt.xlim(-1.0, +1.0)
69+
plt.xlabel("x")
70+
plt.ylabel("T")
71+
plt.tight_layout()
72+
plt.legend()
73+
74+
for i, var_index in enumerate(frames):
75+
plt.subplot(2, len(frames), len(frames) + i + 1)
76+
plt.title(f"x = {frames_val[i]:.2f}")
77+
plt.plot(
78+
y[var_index, :],
79+
pinn_output[var_index, :] * 75.0,
80+
"r--",
81+
lw=4.0,
82+
label="pinn",
83+
)
84+
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
85+
plt.ylim(0.0, 100.0)
86+
plt.xlim(-1.0, +1.0)
87+
plt.xlabel("y")
88+
plt.ylabel("T")
89+
plt.tight_layout()
90+
plt.legend()
91+
92+
plt.savefig(osp.join(cfg.output_dir, "profiles.png"))
93+
94+
2795
def train(cfg: DictConfig):
2896
# set random seed for reproducibility
2997
ppsci.utils.misc.set_random_seed(cfg.seed)
@@ -141,72 +209,7 @@ def train(cfg: DictConfig):
141209
fdm_output = fdm.solve(N_EVAL, 1).T
142210
mse_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
143211
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss}")
144-
145-
x = input_data["x"].reshape(N_EVAL, N_EVAL)
146-
y = input_data["y"].reshape(N_EVAL, N_EVAL)
147-
148-
plt.subplot(2, 1, 1)
149-
plt.pcolormesh(x, y, pinn_output * 75.0, cmap="magma")
150-
plt.colorbar()
151-
plt.title("PINN")
152-
plt.xlabel("x")
153-
plt.ylabel("y")
154-
plt.tight_layout()
155-
plt.axis("square")
156-
157-
plt.subplot(2, 1, 2)
158-
plt.pcolormesh(x, y, fdm_output, cmap="magma")
159-
plt.colorbar()
160-
plt.xlabel("x")
161-
plt.ylabel("y")
162-
plt.title("FDM")
163-
plt.tight_layout()
164-
plt.axis("square")
165-
plt.savefig(osp.join(cfg.output_dir, "pinn_fdm_comparison.png"))
166-
plt.close()
167-
168-
frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
169-
frames = [*map(int, (frames_val + 1) / 2 * (N_EVAL - 1))]
170-
height = 3
171-
plt.figure("", figsize=(len(frames) * height, 2 * height))
172-
173-
for i, var_index in enumerate(frames):
174-
plt.subplot(2, len(frames), i + 1)
175-
plt.title(f"y = {frames_val[i]:.2f}")
176-
plt.plot(
177-
x[:, var_index],
178-
pinn_output[:, var_index] * 75.0,
179-
"r--",
180-
lw=4.0,
181-
label="pinn",
182-
)
183-
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
184-
plt.ylim(0.0, 100.0)
185-
plt.xlim(-1.0, +1.0)
186-
plt.xlabel("x")
187-
plt.ylabel("T")
188-
plt.tight_layout()
189-
plt.legend()
190-
191-
for i, var_index in enumerate(frames):
192-
plt.subplot(2, len(frames), len(frames) + i + 1)
193-
plt.title(f"x = {frames_val[i]:.2f}")
194-
plt.plot(
195-
y[var_index, :],
196-
pinn_output[var_index, :] * 75.0,
197-
"r--",
198-
lw=4.0,
199-
label="pinn",
200-
)
201-
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
202-
plt.ylim(0.0, 100.0)
203-
plt.xlim(-1.0, +1.0)
204-
plt.xlabel("y")
205-
plt.ylabel("T")
206-
plt.tight_layout()
207-
plt.legend()
208-
209-
plt.savefig(osp.join(cfg.output_dir, "profiles.png"))
212+
plot(input_data, N_EVAL, pinn_output, fdm_output, cfg)
210213

211214

212215
def evaluate(cfg: DictConfig):
@@ -239,72 +242,49 @@ def evaluate(cfg: DictConfig):
239242
fdm_output = fdm.solve(N_EVAL, 1).T
240243
mse_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
241244
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss:.5e}")
245+
plot(input_data, N_EVAL, pinn_output, fdm_output, cfg)
242246

243-
x = input_data["x"].reshape(N_EVAL, N_EVAL)
244-
y = input_data["y"].reshape(N_EVAL, N_EVAL)
245247

246-
plt.subplot(2, 1, 1)
247-
plt.pcolormesh(x, y, pinn_output * 75.0, cmap="magma")
248-
plt.colorbar()
249-
plt.title("PINN")
250-
plt.xlabel("x")
251-
plt.ylabel("y")
252-
plt.tight_layout()
253-
plt.axis("square")
248+
def export(cfg: DictConfig):
249+
# set model
250+
model = ppsci.arch.MLP(**cfg.MODEL)
254251

255-
plt.subplot(2, 1, 2)
256-
plt.pcolormesh(x, y, fdm_output, cmap="magma")
257-
plt.colorbar()
258-
plt.xlabel("x")
259-
plt.ylabel("y")
260-
plt.title("FDM")
261-
plt.tight_layout()
262-
plt.axis("square")
263-
plt.savefig(osp.join(cfg.output_dir, "pinn_fdm_comparison.png"))
264-
plt.close()
252+
# initialize solver
253+
solver = ppsci.solver.Solver(
254+
model,
255+
cfg=cfg,
256+
)
257+
# export model
258+
from paddle.static import InputSpec
265259

266-
frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
267-
frames = [*map(int, (frames_val + 1) / 2 * (N_EVAL - 1))]
268-
height = 3
269-
plt.figure("", figsize=(len(frames) * height, 2 * height))
260+
input_spec = [
261+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
262+
]
263+
solver.export(input_spec, cfg.INFER.export_path)
270264

271-
for i, var_index in enumerate(frames):
272-
plt.subplot(2, len(frames), i + 1)
273-
plt.title(f"y = {frames_val[i]:.2f}")
274-
plt.plot(
275-
x[:, var_index],
276-
pinn_output[:, var_index] * 75.0,
277-
"r--",
278-
lw=4.0,
279-
label="pinn",
280-
)
281-
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
282-
plt.ylim(0.0, 100.0)
283-
plt.xlim(-1.0, +1.0)
284-
plt.xlabel("x")
285-
plt.ylabel("T")
286-
plt.tight_layout()
287-
plt.legend()
288265

289-
for i, var_index in enumerate(frames):
290-
plt.subplot(2, len(frames), len(frames) + i + 1)
291-
plt.title(f"x = {frames_val[i]:.2f}")
292-
plt.plot(
293-
y[var_index, :],
294-
pinn_output[var_index, :] * 75.0,
295-
"r--",
296-
lw=4.0,
297-
label="pinn",
298-
)
299-
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
300-
plt.ylim(0.0, 100.0)
301-
plt.xlim(-1.0, +1.0)
302-
plt.xlabel("y")
303-
plt.ylabel("T")
304-
plt.tight_layout()
305-
plt.legend()
266+
def inference(cfg: DictConfig):
267+
from deploy.python_infer import pinn_predictor
306268

307-
plt.savefig(osp.join(cfg.output_dir, "profiles.png"))
269+
predictor = pinn_predictor.PINNPredictor(cfg)
270+
# set geometry
271+
geom = {"rect": ppsci.geometry.Rectangle((-1.0, -1.0), (1.0, 1.0))}
272+
# begin eval
273+
N_EVAL = 100
274+
input_data = geom["rect"].sample_interior(N_EVAL**2, evenly=True)
275+
output_data = predictor.predict(
276+
{key: input_data[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
277+
)
278+
279+
# mapping data to cfg.INFER.output_keys
280+
output_data = {
281+
store_key: output_data[infer_key]
282+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_data.keys())
283+
}["u"].reshape(N_EVAL, N_EVAL)
284+
fdm_output = fdm.solve(N_EVAL, 1).T
285+
mse_loss = np.mean(np.square(output_data - (fdm_output / 75.0)))
286+
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss:.5e}")
287+
plot(input_data, N_EVAL, output_data, fdm_output, cfg)
308288

309289

310290
@hydra.main(version_base=None, config_path="./conf", config_name="heat_pinn.yaml")
@@ -313,8 +293,14 @@ def main(cfg: DictConfig):
313293
train(cfg)
314294
elif cfg.mode == "eval":
315295
evaluate(cfg)
296+
elif cfg.mode == "export":
297+
export(cfg)
298+
elif cfg.mode == "infer":
299+
inference(cfg)
316300
else:
317-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
301+
raise ValueError(
302+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
303+
)
318304

319305

320306
if __name__ == "__main__":

0 commit comments

Comments
 (0)