Commit de42263 (parent 4a84de3)

1. fix vispoints bug in cylinder2d_unsteady_Re100; 2. remove dot at end of print info

4 files changed, +40 / -36 lines

examples/cylinder/2d_unsteady/cylinder2d_unsteady_Re100.py
Lines changed: 7 additions & 8 deletions

@@ -13,16 +13,14 @@
 # limitations under the License.

 import numpy as np
-from paddle import fluid

 import ppsci
 from ppsci.utils import config
 from ppsci.utils import logger
 from ppsci.utils import reader

 if __name__ == "__main__":
-    fluid.core.set_prim_eager_enabled(True)
-
+    # fluid.core.set_prim_eager_enabled(True)
     args = config.parse_args()
     # set random seed for reproducibility
     ppsci.utils.misc.set_random_seed(42)
@@ -89,7 +87,7 @@
     ITERS_PER_EPOCH = 1

     # pde/bc/sup constraint use t1~tn, initial constraint use t0
-    NPOINT_PDE, ntime_pde = 9420, len(train_timestamps)
+    NPOINT_PDE, NTIME_PDE = 9420, len(train_timestamps)
     NPOINT_INLET_CYLINDER = 161
     NPOINT_OUTLET = 81
     ALIAS_DICT = {"x": "Points:0", "y": "Points:1", "u": "U:0", "v": "U:1"}
@@ -101,7 +99,7 @@
         geom["time_rect"],
         {
             "dataset": "IterableNamedArrayDataset",
-            "batch_size": NPOINT_PDE * ntime_pde,
+            "batch_size": NPOINT_PDE * NTIME_PDE,
             "iters_per_epoch": ITERS_PER_EPOCH,
         },
         ppsci.loss.MSELoss("mean"),
@@ -202,14 +200,15 @@

     # set visualizer(optional)
     vis_points = geom["time_rect_eval"].sample_interior(
-        (NPOINT_PDE + NPOINT_INLET_CYLINDER + NPOINT_OUTLET) * NUM_TIMESTAMPS
+        (NPOINT_PDE + NPOINT_INLET_CYLINDER + NPOINT_OUTLET) * NUM_TIMESTAMPS,
+        evenly=True,
     )
     visualizer = {
-        "visulzie_u": ppsci.visualize.VisualizerVtu(
+        "visulzie_u_v_p": ppsci.visualize.VisualizerVtu(
             vis_points,
             {"u": lambda d: d["u"], "v": lambda d: d["v"], "p": lambda d: d["p"]},
             num_timestamps=NUM_TIMESTAMPS,
-            prefix="result_u",
+            prefix="result_u_v_p",
         )
     }
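
The substance of the "vispoints" fix is sampling the evaluation points evenly (evenly=True) and renaming the visualizer key/prefix so the VTU output clearly covers u, v and p. The snippet below is only a sanity check of the point-count arithmetic behind the sample_interior call; it is plain numpy, independent of ppsci, and NUM_TIMESTAMPS is a placeholder value (the script derives the real one from the evaluation timestamps):

import numpy as np

# Point counts from the script; NUM_TIMESTAMPS is hypothetical, for illustration only.
NPOINT_PDE = 9420
NPOINT_INLET_CYLINDER = 161
NPOINT_OUTLET = 81
NUM_TIMESTAMPS = 50

npoint_per_t = NPOINT_PDE + NPOINT_INLET_CYLINDER + NPOINT_OUTLET  # 9662
total = npoint_per_t * NUM_TIMESTAMPS                              # 483100

# Stand-in for the sampled coordinates: one (x, y) row per point and timestamp.
fake_points = np.zeros((total, 2))
# A per-timestamp split, as a visualizer built with num_timestamps presumably needs,
# only works when the total is an exact multiple of the per-timestamp count, which
# the evenly=True sampling over the time-extended geometry is meant to guarantee.
frames = fake_points.reshape(NUM_TIMESTAMPS, npoint_per_t, 2)
print(frames.shape)  # (50, 9662, 2)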

ppsci/solver/solver.py
Lines changed: 2 additions & 2 deletions

@@ -439,7 +439,7 @@ def visualize(self, epoch_id: int = 0):
             self.visu_func = ppsci.solver.visu.visualize_func

         self.visu_func(self, epoch_id)
-        logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization.")
+        logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization")

         if train_state:
             self.model.train()
@@ -516,7 +516,7 @@ def export(self):
         export_dir = self.cfg["Global"]["save_inference_dir"]
         save_path = os.path.join(export_dir, "inference")
         paddle.jit.save(static_model, save_path)
-        logger.info(f"The inference model has been exported to {export_dir}.")
+        logger.info(f"The inference model has been exported to {export_dir}")

     def autocast_context_manager(self) -> contextlib.AbstractContextManager:
         """Autocast context manager for Auto Mix Precision.

ppsci/utils/download.py
Lines changed: 1 addition & 3 deletions

@@ -132,9 +132,7 @@ def _download(url, path, md5sum=None):
             req = requests.get(url, stream=True)
         except Exception as e:  # requests.exceptions.ConnectionError
             logger.info(
-                "Downloading {} from {} failed {} times with exception {}".format(
-                    fname, url, retry_cnt + 1, str(e)
-                )
+                f"Downloading {fname} from {url} failed {retry_cnt + 1} times with exception {str(e)}"
             )
             time.sleep(1)
             continue
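
This hunk only swaps a str.format call for an f-string; the log message itself is unchanged. A throwaway equivalence check, using made-up values for fname, url, retry_cnt and the caught exception:

# Placeholder values, not from the codebase.
fname = "dataset.tar.gz"
url = "https://example.com/dataset.tar.gz"
retry_cnt = 0
e = ConnectionError("connection reset")

old_msg = "Downloading {} from {} failed {} times with exception {}".format(
    fname, url, retry_cnt + 1, str(e)
)
new_msg = f"Downloading {fname} from {url} failed {retry_cnt + 1} times with exception {str(e)}"
assert old_msg == new_msg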

ppsci/utils/save_load.py
Lines changed: 30 additions & 23 deletions

@@ -34,11 +34,11 @@ def _mkdir_if_not_exist(path):
     if not os.path.exists(path):
         try:
             os.makedirs(path)
-        except OSError as e:
-            if e.errno == errno.EEXIST and os.path.isdir(path):
-                logger.warning(f"{path} already created")
+        except OSError as os_err:
+            if os_err.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(f"{path} already created.")
             else:
-                raise OSError(f"Failed to mkdir {path}")
+                raise OSError(f"Failed to mkdir {path}.")


 def _load_pretrain_from_path(model, path, equation=None):
@@ -49,17 +49,20 @@ def _load_pretrain_from_path(model, path, equation=None):
         path (str, optional): Pretrained model path.
         equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None.
     """
-    if not (os.path.isdir(path) or os.path.exists(path + ".pdparams")):
+    if not (os.path.isdir(path) or os.path.exists(f"{path}.pdparams")):
         raise FileNotFoundError(
             f"Pretrained model path {path}.pdparams does not exists."
         )

-    param_state_dict = paddle.load(path + ".pdparams")
+    param_state_dict = paddle.load(f"{path}.pdparams")
     model.set_dict(param_state_dict)
     if equation is not None:
-        equation_dict = paddle.load(path + ".pdeqn")
-        for name, _equation in equation.items():
-            _equation.set_state_dict(equation_dict[name])
+        if not os.path.exists(f"{path}.pdeqn"):
+            logger.warning(f"{path}.pdeqn not found.")
+        else:
+            equation_dict = paddle.load(f"{path}.pdeqn")
+            for name, _equation in equation.items():
+                _equation.set_state_dict(equation_dict[name])

     logger.info(f"Finish loading pretrained model from {path}")

@@ -92,28 +95,32 @@ def load_checkpoint(
     Returns:
         Dict[str, Any]: Loaded metric information.
     """
-    if not os.path.exists(path + ".pdparams"):
+    if not os.path.exists(f"{path}.pdparams"):
         raise FileNotFoundError(f"{path}.pdparams not exist.")
-    if not os.path.exists(path + ".pdopt"):
+    if not os.path.exists(f"{path}.pdopt"):
         raise FileNotFoundError(f"{path}.pdopt not exist.")
-    if grad_scaler is not None and not os.path.exists(path + ".pdscaler"):
+    if grad_scaler is not None and not os.path.exists(f"{path}.pdscaler"):
         raise FileNotFoundError(f"{path}.scaler not exist.")

     # load state dict
-    param_dict = paddle.load(path + ".pdparams")
-    optim_dict = paddle.load(path + ".pdopt")
-    metric_dict = paddle.load(path + ".pdstates")
+    param_dict = paddle.load(f"{path}.pdparams")
+    optim_dict = paddle.load(f"{path}.pdopt")
+    metric_dict = paddle.load(f"{path}.pdstates")
     if grad_scaler is not None:
-        scaler_dict = paddle.load(path + ".pdscaler")
+        scaler_dict = paddle.load(f"{path}.pdscaler")
     if equation is not None:
-        equation_dict = paddle.load(path + ".pdeqn")
+        if not os.path.exists(f"{path}.pdeqn"):
+            logger.warning(f"{path}.pdeqn not found.")
+            equation_dict = None
+        else:
+            equation_dict = paddle.load(f"{path}.pdeqn")

     # set state dict
     model.set_state_dict(param_dict)
     optimizer.set_state_dict(optim_dict)
     if grad_scaler is not None:
         grad_scaler.load_state_dict(scaler_dict)
-    if equation is not None:
+    if equation is not None and equation_dict is not None:
         for name, _equation in equation.items():
             _equation.set_state_dict(equation_dict[name])

@@ -141,15 +148,15 @@ def save_checkpoint(
     _mkdir_if_not_exist(model_dir)
     model_path = os.path.join(model_dir, prefix)

-    paddle.save(model.state_dict(), model_path + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
-    paddle.save(metric, model_path + ".pdstates")
+    paddle.save(model.state_dict(), f"{model_path}.pdparams")
+    paddle.save(optimizer.state_dict(), f"{model_path}.pdopt")
+    paddle.save(metric, f"{model_path}.pdstates")
     if grad_scaler is not None:
-        paddle.save(grad_scaler.state_dict(), model_path + ".pdscaler")
+        paddle.save(grad_scaler.state_dict(), f"{model_path}.pdscaler")
     if equation is not None:
         paddle.save(
             {key: eq.state_dict() for key, eq in equation.items()},
-            model_path + ".pdeqn",
+            f"{model_path}.pdeqn",
         )

     logger.info(f"Finish saving checkpoint to {model_path}")
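
Beyond the string-formatting cleanup, the behavioral change in this file is that a missing "<path>.pdeqn" file is no longer fatal: both _load_pretrain_from_path and load_checkpoint now warn and skip the equation state instead of failing in paddle.load. The hypothetical helper below (not part of the patch) isolates that guarded-load pattern for reference; it only uses calls that appear in the diff (os.path.exists, paddle.load, logger.warning):

import os

import paddle

from ppsci.utils import logger


def _load_optional_pdeqn(path):
    """Load "<path>.pdeqn" if it exists, otherwise warn and return None.

    Illustration of the pattern shared by _load_pretrain_from_path and
    load_checkpoint after this commit; not code from the patch itself.
    """
    if not os.path.exists(f"{path}.pdeqn"):
        logger.warning(f"{path}.pdeqn not found.")
        return None
    return paddle.load(f"{path}.pdeqn")


# Callers then guard the equation update on both the argument and the result:
# equation_dict = _load_optional_pdeqn(path)
# if equation is not None and equation_dict is not None:
#     for name, _equation in equation.items():
#         _equation.set_state_dict(equation_dict[name])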
