Skip to content

Commit bfcb4cf

Browse files
update visualize code
1 parent c77ec5e commit bfcb4cf

File tree

7 files changed

+37
-51
lines changed

7 files changed

+37
-51
lines changed

examples/cylinder/3d_unsteady_discrete/cylinder3d_unsteady.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,6 @@
342342
600000,
343343
label,
344344
time_list,
345-
len(time_list),
346345
"result_uvwp",
347346
)
348347
}

ppsci/data/process/transform/preprocess.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def __call__(self, data):
205205
input_item[key] = value.reshape((B * C, H, W))
206206
if value.ndim != 3:
207207
raise ValueError(
208-
"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
208+
f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
209209
)
210210
if "label" in self.apply_keys:
211211
for key, value in label_item.items():
@@ -214,6 +214,6 @@ def __call__(self, data):
214214
label_item[key] = value.reshape((B * C, H, W))
215215
if value.ndim != 3:
216216
raise ValueError(
217-
"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
217+
f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
218218
)
219219
return input_item, label_item, weight_item

ppsci/solver/eval.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ppsci.utils import profiler
2424

2525

26-
def eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
26+
def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
2727
"""Evaluation program by dataset.
2828
2929
Args:
@@ -63,7 +63,8 @@ def eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
6363
_validator.input_keys, _validator.output_keys, solver.model
6464
)
6565
for output_name, output_formula in _validator.output_expr.items():
66-
evaluator.add_target_expr(output_formula, output_name)
66+
if output_name in label_dict:
67+
evaluator.add_target_expr(output_formula, output_name)
6768

6869
# forward
6970
with solver.autocast_context_manager(), solver.no_grad_context_manager():
@@ -147,7 +148,7 @@ def eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
147148
return target_metric
148149

149150

150-
def eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
151+
def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
151152
"""Evaluation program by batch.
152153
153154
Args:
@@ -260,5 +261,5 @@ def eval_func(solver, epoch_id: int, log_freq: int) -> float:
260261
float: Target metric computed during evaluation.
261262
"""
262263
if solver.compute_metric_by_batch:
263-
return eval_by_batch(solver, epoch_id, log_freq)
264-
return eval_by_dataset(solver, epoch_id, log_freq)
264+
return _eval_by_batch(solver, epoch_id, log_freq)
265+
return _eval_by_dataset(solver, epoch_id, log_freq)

ppsci/solver/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
5656
_constraint.input_keys, _constraint.output_keys, solver.model
5757
)
5858
for output_name, output_formula in _constraint.output_expr.items():
59-
evaluator.add_target_expr(output_formula, output_name)
59+
if output_name in label_dict:
60+
evaluator.add_target_expr(output_formula, output_name)
6061

6162
# forward for every constraint
6263
with solver.autocast_context_manager():

ppsci/visualize/plot.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import imageio
2222
import matplotlib
23-
import matplotlib as mpl
2423
import numpy as np
2524
import paddle
2625
from matplotlib import cm
@@ -174,8 +173,8 @@ def _save_plot_from_2d_array(
174173
"""
175174

176175
plt.close("all")
177-
mpl.rcParams["xtick.labelsize"] = 5
178-
mpl.rcParams["ytick.labelsize"] = 5
176+
matplotlib.rcParams["xtick.labelsize"] = 5
177+
matplotlib.rcParams["ytick.labelsize"] = 5
179178

180179
fig, ax = plt.subplots(
181180
len(visu_keys),
@@ -215,7 +214,7 @@ def _save_plot_from_2d_array(
215214
ticks = np.linspace(0, 1, 5)
216215
tickLabels = np.linspace(c_min, c_max, 5)
217216
tickLabels = [f"{t0:02.2f}" for t0 in tickLabels]
218-
cbar = mpl.colorbar.ColorbarBase(
217+
cbar = matplotlib.colorbar.ColorbarBase(
219218
ax_cbar, cmap=plt.get_cmap("inferno"), orientation="vertical", ticks=ticks
220219
)
221220
cbar.set_ticklabels(tickLabels, fontsize=5)

ppsci/visualize/visualizer.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -302,32 +302,21 @@ def __init__(
302302
batch_size: int = 64,
303303
label_dict: Optional[Dict[str, np.ndarray]] = None,
304304
time_list: Optional[Tuple[float, ...]] = None,
305-
num_timestamps: int = 1,
306305
prefix: str = "vtu",
307306
):
308307
self.label = label_dict
309308
self.time_list = time_list
310-
super().__init__(input_dict, output_expr, batch_size, num_timestamps, prefix)
311-
312-
def save(self, filename: str, data_dict: Dict[str, paddle.Tensor]):
313-
"""Save points result
314-
315-
Args:
316-
filename (str): Output file name with directory.
317-
data_dict (Dict[str, paddle.Tensor]): Predicted result.
318-
"""
309+
super().__init__(input_dict, output_expr, batch_size, len(time_list), prefix)
319310

311+
def save(self, filename: str, data_dict: Dict[str, np.ndarray]):
320312
n = int((next(iter(data_dict.values()))).shape[0] / self.num_timestamps)
321313
coord_keys = [x for x in self.input_dict if x != "t"]
322314
for i in range(len(self.time_list)):
323315
vtu.save_vtu_to_mesh(
324-
filename=osp.join(filename, f"predict_{i+1}.vtu"),
325-
data_dict={
326-
key: (data_dict[key].numpy()[i * n : (i + 1) * n])
327-
for key in data_dict
328-
},
329-
value_keys=self.output_expr,
330-
coord_keys=coord_keys,
316+
osp.join(filename, f"predict_{i+1}.vtu"),
317+
{key: (data_dict[key][i * n : (i + 1) * n]) for key in data_dict},
318+
coord_keys,
319+
self.output_keys,
331320
)
332321

333322

ppsci/visualize/vtu.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
from typing import Dict
1716
from typing import Tuple
1817

@@ -93,12 +92,18 @@ def _save_vtu_from_array(filename, coord, value, value_keys, num_timestamps=1):
9392
logger.info(f"Visualization result is saved to {filename}.vtu")
9493

9594

96-
def save_vtu_from_dict(filename, data_dict, coord_keys, value_keys, num_timestamps=1):
95+
def save_vtu_from_dict(
96+
filename: str,
97+
data_dict: Dict[str, np.ndarray],
98+
coord_keys: Tuple[str, ...],
99+
value_keys: Tuple[str, ...],
100+
num_timestamps: int = 1,
101+
):
97102
"""Save dict data to '*.vtu' file.
98103
99104
Args:
100105
filename (str): Output filename.
101-
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
106+
data_dict (Dict[str, np.ndarray]): Data in dict.
102107
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
103108
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
104109
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
@@ -109,17 +114,9 @@ def save_vtu_from_dict(filename, data_dict, coord_keys, value_keys, num_timestam
109114
coord = [data_dict[k] for k in coord_keys if k != "t"]
110115
value = [data_dict[k] for k in value_keys] if value_keys else None
111116

112-
if isinstance(coord[0], paddle.Tensor):
113-
coord = [x.numpy() for x in coord]
114-
else:
115-
coord = [x for x in coord]
116117
coord = np.concatenate(coord, axis=1)
117118

118119
if value is not None:
119-
if isinstance(value[0], paddle.Tensor):
120-
value = [x.numpy() for x in value]
121-
else:
122-
value = [x for x in value]
123120
value = np.concatenate(value, axis=1)
124121

125122
_save_vtu_from_array(filename, coord, value, value_keys, num_timestamps)
@@ -128,26 +125,26 @@ def save_vtu_from_dict(filename, data_dict, coord_keys, value_keys, num_timestam
128125
def save_vtu_to_mesh(
129126
filename: str,
130127
data_dict: Dict[str, np.ndarray],
131-
value_keys: Tuple[str, ...],
132128
coord_keys: Tuple[str, ...],
133-
num_timestamps: int = 1,
129+
value_keys: Tuple[str, ...],
134130
):
135131
"""Save data into .vtu format by meshio.
136132
137133
Args:
138134
filename (str): File name.
139-
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
135+
data_dict (Dict[str, np.ndarray]): Data in dict.
140136
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
141137
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
142-
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
143138
"""
144-
path = os.path.dirname(filename)
145-
os.makedirs(path, exist_ok=True)
139+
npoint = len(next(iter(data_dict.values())))
140+
coord_ndim = len(coord_keys)
146141

147-
n = len(next(iter(data_dict.values())))
148-
m = len(coord_keys)
149142
# get the list variable transposed
150-
points = np.stack((data_dict[key] for key in coord_keys)).reshape(m, n)
151-
mesh = meshio.Mesh(points=points.T, cells=[("vertex", np.arange(n).reshape(n, 1))])
143+
points = np.stack((data_dict[key] for key in coord_keys)).reshape(
144+
coord_ndim, npoint
145+
)
146+
mesh = meshio.Mesh(
147+
points=points.T, cells=[("vertex", np.arange(npoint).reshape(npoint, 1))]
148+
)
152149
mesh.point_data = {key: data_dict[key] for key in value_keys}
153150
mesh.write(filename)

0 commit comments

Comments
 (0)