66
66
]
67
67
68
68
69
- def _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamp = 1 ):
69
+ def _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamps = 1 ):
70
70
"""Save plot from given 1D data.
71
71
72
72
Args:
73
73
filename (str): Filename.
74
74
coord (np.ndarray): Coordinate array.
75
75
value (Dict[str, np.ndarray]): Dict of value array.
76
76
value_keys (Tuple[str, ...]): Value keys.
77
- num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
77
+ num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
78
78
"""
79
- fig , a = plt .subplots (len (value_keys ), num_timestamp , squeeze = False )
79
+ fig , a = plt .subplots (len (value_keys ), num_timestamps , squeeze = False )
80
80
fig .subplots_adjust (hspace = 0.8 )
81
81
82
- len_ts = len (coord ) // num_timestamp
83
- for t in range (num_timestamp ):
82
+ len_ts = len (coord ) // num_timestamps
83
+ for t in range (num_timestamps ):
84
84
st = t * len_ts
85
85
ed = (t + 1 ) * len_ts
86
86
coord_t = coord [st :ed ]
@@ -93,29 +93,29 @@ def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp=1
93
93
color = cnames [i ],
94
94
label = key ,
95
95
)
96
- if num_timestamp > 1 :
96
+ if num_timestamps > 1 :
97
97
a [i ][t ].set_title (f"{ key } (t={ t } )" )
98
98
else :
99
99
a [i ][t ].set_title (f"{ key } " )
100
100
a [i ][t ].grid ()
101
101
a [i ][t ].legend ()
102
102
103
- if num_timestamp == 1 :
103
+ if num_timestamps == 1 :
104
104
fig .savefig (filename , dpi = 300 )
105
105
else :
106
106
fig .savefig (f"{ filename } _{ t } " , dpi = 300 )
107
107
108
- if num_timestamp == 1 :
108
+ if num_timestamps == 1 :
109
109
logger .info (f"1D result is saved to { filename } .png" )
110
110
else :
111
111
logger .info (
112
112
f"1D result is saved to { filename } _0.png"
113
- f" ~ { filename } _{ num_timestamp - 1 } .png"
113
+ f" ~ { filename } _{ num_timestamps - 1 } .png"
114
114
)
115
115
116
116
117
117
def save_plot_from_1d_dict (
118
- filename , data_dict , coord_keys , value_keys , num_timestamp = 1
118
+ filename , data_dict , coord_keys , value_keys , num_timestamps = 1
119
119
):
120
120
"""Plot dict data as file.
121
121
@@ -124,7 +124,7 @@ def save_plot_from_1d_dict(
124
124
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
125
125
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
126
126
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
127
- num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
127
+ num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
128
128
"""
129
129
space_ndim = len (coord_keys ) - int ("t" in coord_keys )
130
130
if space_ndim not in [1 , 2 , 3 ]:
@@ -146,14 +146,14 @@ def save_plot_from_1d_dict(
146
146
value = [x for x in value ]
147
147
value = np .concatenate (value , axis = 1 )
148
148
149
- _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamp )
149
+ _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamps )
150
150
151
151
152
152
def _save_plot_from_2d_array (
153
153
filename : str ,
154
154
visu_data : Tuple [np .ndarray , ...],
155
155
visu_keys : Tuple [str , ...],
156
- num_timestamp : int = 1 ,
156
+ num_timestamps : int = 1 ,
157
157
stride : int = 1 ,
158
158
xticks : Optional [Tuple [float , ...]] = None ,
159
159
yticks : Optional [Tuple [float , ...]] = None ,
@@ -164,7 +164,7 @@ def _save_plot_from_2d_array(
164
164
filename (str): Filename.
165
165
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
166
166
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
167
- num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
167
+ num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
168
168
stride (int, optional): The time stride of visualization. Defaults to 1.
169
169
xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None.
170
170
yticks (Optional[Tuple[float, ...]]): Tuple of ytick locations. Defaults to None.
@@ -176,10 +176,10 @@ def _save_plot_from_2d_array(
176
176
177
177
fig , ax = plt .subplots (
178
178
len (visu_keys ),
179
- num_timestamp ,
179
+ num_timestamps ,
180
180
squeeze = False ,
181
181
sharey = True ,
182
- figsize = (num_timestamp , len (visu_keys )),
182
+ figsize = (num_timestamps , len (visu_keys )),
183
183
)
184
184
fig .subplots_adjust (hspace = 0.3 )
185
185
target_flag = any (["target" in key for key in visu_keys ])
@@ -188,7 +188,7 @@ def _save_plot_from_2d_array(
188
188
c_max = np .amax (data )
189
189
c_min = np .amin (data )
190
190
191
- for t_idx in range (num_timestamp ):
191
+ for t_idx in range (num_timestamps ):
192
192
t = t_idx * stride
193
193
ax [i , t_idx ].imshow (
194
194
data [t , :, :],
@@ -223,7 +223,7 @@ def save_plot_from_2d_dict(
223
223
filename : str ,
224
224
data_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
225
225
visu_keys : Tuple [str , ...],
226
- num_timestamp : int = 1 ,
226
+ num_timestamps : int = 1 ,
227
227
stride : int = 1 ,
228
228
xticks : Optional [Tuple [float , ...]] = None ,
229
229
yticks : Optional [Tuple [float , ...]] = None ,
@@ -234,7 +234,7 @@ def save_plot_from_2d_dict(
234
234
filename (str): Output filename.
235
235
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
236
236
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
237
- num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
237
+ num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
238
238
stride (int, optional): The time stride of visualization. Defaults to 1.
239
239
xticks (Optional[Tuple[float,...]]): The list of xtick locations. Defaults to None.
240
240
yticks (Optional[Tuple[float,...]]): The list of ytick locations. Defaults to None.
@@ -243,7 +243,7 @@ def save_plot_from_2d_dict(
243
243
if isinstance (visu_data [0 ], paddle .Tensor ):
244
244
visu_data = [x .numpy () for x in visu_data ]
245
245
_save_plot_from_2d_array (
246
- filename , visu_data , visu_keys , num_timestamp , stride , xticks , yticks
246
+ filename , visu_data , visu_keys , num_timestamps , stride , xticks , yticks
247
247
)
248
248
249
249
@@ -305,21 +305,21 @@ def _save_plot_from_3d_array(
305
305
filename : str ,
306
306
visu_data : Tuple [np .ndarray , ...],
307
307
visu_keys : Tuple [str , ...],
308
- num_timestamp : int = 1 ,
308
+ num_timestamps : int = 1 ,
309
309
):
310
310
"""Save plot from given 3D data.
311
311
312
312
Args:
313
313
filename (str): Filename.
314
314
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
315
315
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
316
- num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
316
+ num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
317
317
"""
318
318
319
319
fig = plt .figure (figsize = (10 , 10 ))
320
- len_ts = len (visu_data [0 ]) // num_timestamp
321
- for t in range (num_timestamp ):
322
- ax = fig .add_subplot (1 , num_timestamp , t + 1 , projection = "3d" )
320
+ len_ts = len (visu_data [0 ]) // num_timestamps
321
+ for t in range (num_timestamps ):
322
+ ax = fig .add_subplot (1 , num_timestamps , t + 1 , projection = "3d" )
323
323
st = t * len_ts
324
324
ed = (t + 1 ) * len_ts
325
325
visu_data_t = [data [st :ed ] for data in visu_data ]
@@ -340,37 +340,37 @@ def _save_plot_from_3d_array(
340
340
loc = "upper right" ,
341
341
framealpha = 0.95 ,
342
342
)
343
- if num_timestamp == 1 :
343
+ if num_timestamps == 1 :
344
344
fig .savefig (filename , dpi = 300 )
345
345
else :
346
346
fig .savefig (f"{ filename } _{ t } " , dpi = 300 )
347
347
348
- if num_timestamp == 1 :
348
+ if num_timestamps == 1 :
349
349
logger .info (f"3D result is saved to { filename } .png" )
350
350
else :
351
351
logger .info (
352
352
f"3D result is saved to { filename } _0.png"
353
- f" ~ { filename } _{ num_timestamp - 1 } .png"
353
+ f" ~ { filename } _{ num_timestamps - 1 } .png"
354
354
)
355
355
356
356
357
357
def save_plot_from_3d_dict (
358
358
filename : str ,
359
359
data_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
360
360
visu_keys : Tuple [str , ...],
361
- num_timestamp : int = 1 ,
361
+ num_timestamps : int = 1 ,
362
362
):
363
363
"""Plot dict data as file.
364
364
365
365
Args:
366
366
filename (str): Output filename.
367
367
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
368
368
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
369
- num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
369
+ num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
370
370
"""
371
371
372
372
visu_data = [data_dict [k ] for k in visu_keys ]
373
373
if isinstance (visu_data [0 ], paddle .Tensor ):
374
374
visu_data = [x .numpy () for x in visu_data ]
375
375
376
- _save_plot_from_3d_array (filename , visu_data , visu_keys , num_timestamp )
376
+ _save_plot_from_3d_array (filename , visu_data , visu_keys , num_timestamps )
0 commit comments