Skip to content

Commit a00eeee

Browse files
3d cylinder merge (#249)
* delete useless files * 3d cylinder new api * The precision is already aligned * merge vtk generator * using meshio for writing vtk * fixed some error in sampler cfg * delete useless print and import * merge vtk_generator into dataset and visulizer * Delete useless inputs and rewrite ReadMe * alignment loss * delete some useless vars * delete enum * simplify api and fix bugs * weight_exper default value : -1.0 -> -1 * delete banished functions * fit develop code style * delete useless line * refactor hardcode and reformat by reviews * delete useless lines * merge default_collate_fn_allow_none * delete useless file * combine visu func, change loss weight design * keep blank lines unchanged * refine code * refine visu and visualizer3D * refine visualizer3D * fix visulizer to fit visu * fix debug lines * delete lines and normalize comment * reformat function and docstring * reformat save_vtu_to_mesh interface * fix missing bug * refactor pointcloud * Bug fixes and delete filter --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent d8e07c7 commit a00eeee

File tree

20 files changed

+1001
-83
lines changed

20 files changed

+1001
-83
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[//]: <> (title: Flow around a cylinder use case tutorial, author: Xiandong Liu @liuxiandong at baidu.com)
2+
3+
4+
# Flow around a Cylinder
5+
6+
This guide introduces to how to build a PINN model for simulating the flow around a cylinder in PaddleScience.
7+
In this example, two versions are provided. It is recommended to pay attention to the baseline version first.
8+
If you want higher training speed or want to run on distributed systems, please pay attention to the optimize version.
9+
10+
11+
## Run
12+
This guide introduces to how to build a PINN model for simulating the flow around a cylinder in PaddleScience.
13+
Run the command as follows:
14+
```
15+
cd ./PaddleScience/examples/cylinder/3d_unsteady_discrete
16+
python3.7 cylinder3d_unsteady.py
17+
```
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Created in Mar. 2023
17+
@author: Guan Wang
18+
"""
19+
20+
import numpy as np
21+
22+
import ppsci
23+
import ppsci.data.process.transform as transform
24+
import ppsci.utils.reader as reader
25+
from ppsci.utils import logger
26+
27+
if __name__ == "__main__":
28+
# set random seed for reproducibility
29+
ppsci.utils.misc.set_random_seed(42)
30+
31+
# set output directory
32+
output_dir = "./output_cylinder3d_unsteady_Re3900"
33+
34+
# set reference file name without time index
35+
ref_file = "data/LBM_result/cylinder3d_2023_1_31_LBM_"
36+
37+
# initialize logger
38+
logger.init_logger("ppsci", f"{output_dir}/train.log", "info")
39+
40+
# set model
41+
model = ppsci.arch.MLP(
42+
("t", "x", "y", "z"),
43+
("u", "v", "w", "p"),
44+
5,
45+
512,
46+
)
47+
48+
# set equation and necessary constant
49+
RENOLDS_NUMBER = 3900
50+
U0 = 0.1
51+
D_CYLINDER = 80
52+
RHO = 1
53+
NU = RHO * U0 * D_CYLINDER / RENOLDS_NUMBER
54+
55+
T_STAR = D_CYLINDER / U0 # 800
56+
XYZ_STAR = D_CYLINDER # 80
57+
UVW_STAR = U0 # 0.1
58+
P_STAR = RHO * U0 * U0 # 0.01
59+
# N-S, Re=3900, D=80, u=0.1, nu=80/3900; nu = rho u D / Re = 1.0 * 0.1 * 80 / 3900
60+
equation = {"NavierStokes": ppsci.equation.NavierStokes(NU, RHO, 3, True)}
61+
62+
# set geometry
63+
norm_factor = {
64+
"t": T_STAR,
65+
"x": XYZ_STAR,
66+
"y": XYZ_STAR,
67+
"z": XYZ_STAR,
68+
"u": UVW_STAR,
69+
"v": UVW_STAR,
70+
"w": UVW_STAR,
71+
"p": P_STAR,
72+
}
73+
normalize = transform.Scale({key: 1 / value for key, value in norm_factor.items()})
74+
interior_data = reader.load_vtk_with_time_file(
75+
"data/sample_points/interior_txyz.vtu"
76+
)
77+
geom = {
78+
"interior": ppsci.geometry.PointCloud(
79+
interior=normalize(interior_data),
80+
coord_keys=("t", "x", "y", "z"),
81+
)
82+
}
83+
84+
# set dataloader config
85+
batchsize_interior = 4000
86+
batchsize_inlet = 256
87+
batchsize_outlet = 256
88+
batchsize_cylinder = 256
89+
batchsize_top = 1280
90+
batchsize_bottom = 1280
91+
batchsize_ic = 6400
92+
batchsize_supervised = 6400
93+
94+
# set time array
95+
INITIAL_TIME = 200000
96+
START_TIME = 200050
97+
END_TIME = 204950
98+
TIME_STEP = 50
99+
TIME_NUMBER = int((END_TIME - START_TIME) / TIME_STEP) + 1
100+
time_list = np.linspace(
101+
int((START_TIME - INITIAL_TIME) / TIME_STEP),
102+
int((END_TIME - INITIAL_TIME) / TIME_STEP),
103+
TIME_NUMBER,
104+
endpoint=True,
105+
).astype("int64")
106+
time_tmp = time_list * TIME_STEP
107+
time_index = np.random.choice(time_list, int(TIME_NUMBER / 2.5), replace=False)
108+
time_index.sort()
109+
time_array = time_index * TIME_STEP
110+
111+
# set constraint
112+
train_dataloader_cfg = {
113+
"sampler": {
114+
"name": "BatchSampler",
115+
"shuffle": False,
116+
"drop_last": False,
117+
},
118+
"num_workers": 1,
119+
}
120+
# interior data
121+
pde_constraint = ppsci.constraint.InteriorConstraint(
122+
equation["NavierStokes"].equations,
123+
{"continuity": 0, "momentum_x": 0, "momentum_y": 0, "momentum_z": 0},
124+
geom["interior"],
125+
evenly=True,
126+
dataloader_cfg={
127+
**train_dataloader_cfg,
128+
"iters_per_epoch": int(geom["interior"].len / batchsize_interior),
129+
"dataset": "NamedArrayDataset",
130+
"batch_size": batchsize_interior,
131+
},
132+
loss=ppsci.loss.MSELoss("mean", 1),
133+
name="INTERIOR",
134+
)
135+
136+
norm_cfg = {
137+
"Scale": {"scale": {key: 1 / value for key, value in norm_factor.items()}}
138+
}
139+
bc_inlet = ppsci.constraint.SupervisedConstraint(
140+
dataloader_cfg={
141+
**train_dataloader_cfg,
142+
"dataset": {
143+
"name": "VtuDataset",
144+
"file_path": "data/sample_points/inlet_txyz.vtu",
145+
"input_keys": model.input_keys,
146+
"label_keys": ("u", "v", "w"),
147+
"labels": {"u": 0.1, "v": 0, "w": 0},
148+
"transforms": [norm_cfg],
149+
},
150+
"batch_size": batchsize_inlet,
151+
},
152+
loss=ppsci.loss.MSELoss("mean", 2),
153+
name="BC_INLET",
154+
)
155+
bc_cylinder = ppsci.constraint.SupervisedConstraint(
156+
dataloader_cfg={
157+
**train_dataloader_cfg,
158+
"dataset": {
159+
"name": "VtuDataset",
160+
"file_path": "data/sample_points/cylinder_txyz.vtu",
161+
"input_keys": model.input_keys,
162+
"label_keys": ("u", "v", "w"),
163+
"labels": {"u": 0, "v": 0, "w": 0},
164+
"transforms": [norm_cfg],
165+
},
166+
"batch_size": batchsize_cylinder,
167+
},
168+
loss=ppsci.loss.MSELoss("mean", 5),
169+
name="BC_CYLINDER",
170+
)
171+
bc_outlet = ppsci.constraint.SupervisedConstraint(
172+
dataloader_cfg={
173+
**train_dataloader_cfg,
174+
"dataset": {
175+
"name": "VtuDataset",
176+
"file_path": "data/sample_points/outlet_txyz.vtu",
177+
"input_keys": model.input_keys,
178+
"label_keys": ("p",),
179+
"labels": {"p": 0},
180+
"transforms": [norm_cfg],
181+
},
182+
"batch_size": batchsize_outlet,
183+
},
184+
loss=ppsci.loss.MSELoss("mean", 1),
185+
name="BC_OUTLET",
186+
)
187+
188+
bc_top = ppsci.constraint.SupervisedConstraint(
189+
dataloader_cfg={
190+
**train_dataloader_cfg,
191+
"dataset": {
192+
"name": "VtuDataset",
193+
"file_path": "data/sample_points/top_txyz.vtu",
194+
"input_keys": model.input_keys,
195+
"label_keys": ("u", "v", "w"),
196+
"labels": {"u": 0.1, "v": 0, "w": 0},
197+
"transforms": [norm_cfg],
198+
},
199+
"batch_size": batchsize_top,
200+
},
201+
loss=ppsci.loss.MSELoss("mean", 2),
202+
name="BC_TOP",
203+
)
204+
205+
bc_bottom = ppsci.constraint.SupervisedConstraint(
206+
dataloader_cfg={
207+
**train_dataloader_cfg,
208+
"dataset": {
209+
"name": "VtuDataset",
210+
"file_path": "data/sample_points/bottom_txyz.vtu",
211+
"input_keys": model.input_keys,
212+
"label_keys": ("u", "v", "w"),
213+
"labels": {"u": 0.1, "v": 0, "w": 0},
214+
"transforms": [norm_cfg],
215+
},
216+
"batch_size": batchsize_bottom,
217+
},
218+
loss=ppsci.loss.MSELoss("mean", 2),
219+
name="BC_BOTTOM",
220+
)
221+
ic = ppsci.constraint.SupervisedConstraint(
222+
dataloader_cfg={
223+
**train_dataloader_cfg,
224+
"dataset": {
225+
"name": "VtuDataset",
226+
"file_path": ref_file,
227+
"input_keys": model.input_keys,
228+
"label_keys": ("u", "v", "w"),
229+
"time_step": TIME_STEP,
230+
"time_index": (0,),
231+
"transforms": [norm_cfg],
232+
},
233+
"batch_size": batchsize_ic,
234+
},
235+
loss=ppsci.loss.MSELoss("mean", 5),
236+
name="IC",
237+
)
238+
sup = ppsci.constraint.SupervisedConstraint(
239+
dataloader_cfg={
240+
**train_dataloader_cfg,
241+
"dataset": {
242+
"name": "VtuDataset",
243+
"file_path": "data/sup_data/supervised_",
244+
"input_keys": model.input_keys,
245+
"label_keys": ("u", "v", "w"),
246+
"time_step": TIME_STEP,
247+
"time_index": time_index,
248+
"transforms": (norm_cfg,),
249+
},
250+
"batch_size": batchsize_supervised,
251+
},
252+
loss=ppsci.loss.MSELoss("mean", 10),
253+
name="SUP",
254+
)
255+
# wrap constraints together
256+
constraint = {
257+
pde_constraint.name: pde_constraint,
258+
bc_inlet.name: bc_inlet,
259+
bc_cylinder.name: bc_cylinder,
260+
bc_outlet.name: bc_outlet,
261+
bc_top.name: bc_top,
262+
bc_bottom.name: bc_bottom,
263+
ic.name: ic,
264+
sup.name: sup,
265+
}
266+
267+
# set training hyper-parameters
268+
epochs = 400000
269+
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
270+
epochs=epochs,
271+
iters_per_epoch=1,
272+
learning_rate=0.001,
273+
warmup_epoch=int(epochs * 0.125),
274+
)()
275+
276+
# set optimizer
277+
optimizer = ppsci.optimizer.Adam(learning_rate=lr_scheduler)((model,))
278+
279+
# Read validation reference for time step : 0, 99
280+
lbm_0_input, lbm_0_label = reader.load_vtk_file(
281+
ref_file, TIME_STEP, (0,), model.input_keys, model.output_keys
282+
)
283+
lbm_0_dict = {**normalize(lbm_0_input), **normalize(lbm_0_label)}
284+
285+
# set visualizer(optional)
286+
eval_dataloader_cfg = {
287+
"sampler": {
288+
"name": "BatchSampler",
289+
"shuffle": False,
290+
"drop_last": False,
291+
},
292+
"num_workers": 0,
293+
}
294+
validator = {
295+
"Residual": ppsci.validate.SupervisedValidator(
296+
dataloader_cfg={
297+
**eval_dataloader_cfg,
298+
"dataset": {
299+
"name": "VtuDataset",
300+
"file_path": ref_file,
301+
"input_keys": model.input_keys,
302+
"label_keys": ("u", "v", "w"),
303+
"time_step": TIME_STEP,
304+
"time_index": (0,),
305+
"transforms": [norm_cfg],
306+
},
307+
"total_size": len(next(iter(lbm_0_dict.values()))),
308+
"batch_size": 1024,
309+
},
310+
loss=ppsci.loss.MSELoss("mean"),
311+
metric={"MSE": ppsci.metric.MSE()},
312+
name="Residual",
313+
),
314+
}
315+
316+
# set visualizer(optional)
317+
onestep_input, _ = reader.load_vtk_file(ref_file, 0, [0], model.input_keys, ())
318+
data_len_for_onestep = len(next(iter(onestep_input.values())))
319+
input_dict = {
320+
"t": np.concatenate(
321+
[np.full((data_len_for_onestep, 1), t, "float32") for t in time_tmp], axis=0
322+
),
323+
"x": np.tile(onestep_input["x"], (len(time_tmp), 1)),
324+
"y": np.tile(onestep_input["y"], (len(time_tmp), 1)),
325+
"z": np.tile(onestep_input["z"], (len(time_tmp), 1)),
326+
}
327+
input_dict = normalize(input_dict)
328+
_, label = reader.load_vtk_file(
329+
ref_file, TIME_STEP, time_list, model.input_keys, model.output_keys
330+
)
331+
332+
denormalize = transform.Scale(norm_factor)
333+
visualizer = {
334+
"visulzie_uvwp": ppsci.visualize.Visualizer3D(
335+
input_dict,
336+
{
337+
"u": lambda out: out["u"] * norm_factor["u"],
338+
"v": lambda out: out["v"] * norm_factor["v"],
339+
"w": lambda out: out["w"] * norm_factor["w"],
340+
"p": lambda out: out["p"] * norm_factor["p"],
341+
},
342+
600000,
343+
label,
344+
time_list,
345+
len(time_list),
346+
"result_uvwp",
347+
)
348+
}
349+
350+
# initialize solver
351+
solver = ppsci.solver.Solver(
352+
model,
353+
constraint,
354+
output_dir,
355+
optimizer,
356+
lr_scheduler,
357+
epochs,
358+
1,
359+
save_freq=1000,
360+
eval_during_train=False,
361+
eval_freq=1000,
362+
equation=equation,
363+
geom=None,
364+
validator=validator,
365+
)
366+
# train model
367+
solver.train()

0 commit comments

Comments
 (0)