69
69
]
70
70
71
71
72
- def _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamp = 1 ):
72
+ def _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamps = 1 ):
73
73
"""Save plot from given 1D data.
74
74
75
75
Args:
76
76
filename (str): Filename.
77
77
coord (np.ndarray): Coordinate array.
78
78
value (Dict[str, np.ndarray]): Dict of value array.
79
79
value_keys (Tuple[str, ...]): Value keys.
80
- num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
80
+ num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
81
81
"""
82
- fig , a = plt .subplots (len (value_keys ), num_timestamp , squeeze = False )
82
+ fig , a = plt .subplots (len (value_keys ), num_timestamps , squeeze = False )
83
83
fig .subplots_adjust (hspace = 0.8 )
84
84
85
- len_ts = len (coord ) // num_timestamp
86
- for t in range (num_timestamp ):
85
+ len_ts = len (coord ) // num_timestamps
86
+ for t in range (num_timestamps ):
87
87
st = t * len_ts
88
88
ed = (t + 1 ) * len_ts
89
89
coord_t = coord [st :ed ]
@@ -96,29 +96,29 @@ def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp=1
96
96
color = cnames [i ],
97
97
label = key ,
98
98
)
99
- if num_timestamp > 1 :
99
+ if num_timestamps > 1 :
100
100
a [i ][t ].set_title (f"{ key } (t={ t } )" )
101
101
else :
102
102
a [i ][t ].set_title (f"{ key } " )
103
103
a [i ][t ].grid ()
104
104
a [i ][t ].legend ()
105
105
106
- if num_timestamp == 1 :
106
+ if num_timestamps == 1 :
107
107
fig .savefig (filename , dpi = 300 )
108
108
else :
109
109
fig .savefig (f"{ filename } _{ t } " , dpi = 300 )
110
110
111
- if num_timestamp == 1 :
111
+ if num_timestamps == 1 :
112
112
logger .info (f"1D result is saved to { filename } .png" )
113
113
else :
114
114
logger .info (
115
115
f"1D result is saved to { filename } _0.png"
116
- f" ~ { filename } _{ num_timestamp - 1 } .png"
116
+ f" ~ { filename } _{ num_timestamps - 1 } .png"
117
117
)
118
118
119
119
120
120
def save_plot_from_1d_dict (
121
- filename , data_dict , coord_keys , value_keys , num_timestamp = 1
121
+ filename , data_dict , coord_keys , value_keys , num_timestamps = 1
122
122
):
123
123
"""Plot dict data as file.
124
124
@@ -127,7 +127,7 @@ def save_plot_from_1d_dict(
127
127
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
128
128
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
129
129
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
130
- num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
130
+ num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
131
131
"""
132
132
space_ndim = len (coord_keys ) - int ("t" in coord_keys )
133
133
if space_ndim not in [1 , 2 , 3 ]:
@@ -149,14 +149,14 @@ def save_plot_from_1d_dict(
149
149
value = [x for x in value ]
150
150
value = np .concatenate (value , axis = 1 )
151
151
152
- _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamp )
152
+ _save_plot_from_1d_array (filename , coord , value , value_keys , num_timestamps )
153
153
154
154
155
155
def _save_plot_from_2d_array (
156
156
filename : str ,
157
157
visu_data : Tuple [np .ndarray , ...],
158
158
visu_keys : Tuple [str , ...],
159
- num_timestamp : int = 1 ,
159
+ num_timestamps : int = 1 ,
160
160
stride : int = 1 ,
161
161
xticks : Optional [Tuple [float , ...]] = None ,
162
162
yticks : Optional [Tuple [float , ...]] = None ,
@@ -167,7 +167,7 @@ def _save_plot_from_2d_array(
167
167
filename (str): Filename.
168
168
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
169
169
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
170
- num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
170
+ num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
171
171
stride (int, optional): The time stride of visualization. Defaults to 1.
172
172
xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None.
173
173
yticks (Optional[Tuple[float, ...]]): Tuple of ytick locations. Defaults to None.
@@ -179,10 +179,10 @@ def _save_plot_from_2d_array(
179
179
180
180
fig , ax = plt .subplots (
181
181
len (visu_keys ),
182
- num_timestamp ,
182
+ num_timestamps ,
183
183
squeeze = False ,
184
184
sharey = True ,
185
- figsize = (num_timestamp , len (visu_keys )),
185
+ figsize = (num_timestamps , len (visu_keys )),
186
186
)
187
187
fig .subplots_adjust (hspace = 0.3 )
188
188
target_flag = any (["target" in key for key in visu_keys ])
@@ -191,7 +191,7 @@ def _save_plot_from_2d_array(
191
191
c_max = np .amax (data )
192
192
c_min = np .amin (data )
193
193
194
- for t_idx in range (num_timestamp ):
194
+ for t_idx in range (num_timestamps ):
195
195
t = t_idx * stride
196
196
ax [i , t_idx ].imshow (
197
197
data [t , :, :],
@@ -226,7 +226,7 @@ def save_plot_from_2d_dict(
226
226
filename : str ,
227
227
data_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
228
228
visu_keys : Tuple [str , ...],
229
- num_timestamp : int = 1 ,
229
+ num_timestamps : int = 1 ,
230
230
stride : int = 1 ,
231
231
xticks : Optional [Tuple [float , ...]] = None ,
232
232
yticks : Optional [Tuple [float , ...]] = None ,
@@ -237,7 +237,7 @@ def save_plot_from_2d_dict(
237
237
filename (str): Output filename.
238
238
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
239
239
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
240
- num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
240
+ num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
241
241
stride (int, optional): The time stride of visualization. Defaults to 1.
242
242
xticks (Optional[Tuple[float,...]]): The list of xtick locations. Defaults to None.
243
243
yticks (Optional[Tuple[float,...]]): The list of ytick locations. Defaults to None.
@@ -246,7 +246,7 @@ def save_plot_from_2d_dict(
246
246
if isinstance (visu_data [0 ], paddle .Tensor ):
247
247
visu_data = [x .numpy () for x in visu_data ]
248
248
_save_plot_from_2d_array (
249
- filename , visu_data , visu_keys , num_timestamp , stride , xticks , yticks
249
+ filename , visu_data , visu_keys , num_timestamps , stride , xticks , yticks
250
250
)
251
251
252
252
@@ -308,21 +308,21 @@ def _save_plot_from_3d_array(
308
308
filename : str ,
309
309
visu_data : Tuple [np .ndarray , ...],
310
310
visu_keys : Tuple [str , ...],
311
- num_timestamp : int = 1 ,
311
+ num_timestamps : int = 1 ,
312
312
):
313
313
"""Save plot from given 3D data.
314
314
315
315
Args:
316
316
filename (str): Filename.
317
317
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
318
318
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
319
- num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
319
+ num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
320
320
"""
321
321
322
322
fig = plt .figure (figsize = (10 , 10 ))
323
- len_ts = len (visu_data [0 ]) // num_timestamp
324
- for t in range (num_timestamp ):
325
- ax = fig .add_subplot (1 , num_timestamp , t + 1 , projection = "3d" )
323
+ len_ts = len (visu_data [0 ]) // num_timestamps
324
+ for t in range (num_timestamps ):
325
+ ax = fig .add_subplot (1 , num_timestamps , t + 1 , projection = "3d" )
326
326
st = t * len_ts
327
327
ed = (t + 1 ) * len_ts
328
328
visu_data_t = [data [st :ed ] for data in visu_data ]
@@ -343,40 +343,40 @@ def _save_plot_from_3d_array(
343
343
loc = "upper right" ,
344
344
framealpha = 0.95 ,
345
345
)
346
- if num_timestamp == 1 :
346
+ if num_timestamps == 1 :
347
347
fig .savefig (filename , dpi = 300 )
348
348
else :
349
349
fig .savefig (f"{ filename } _{ t } " , dpi = 300 )
350
350
351
- if num_timestamp == 1 :
351
+ if num_timestamps == 1 :
352
352
logger .info (f"3D result is saved to { filename } .png" )
353
353
else :
354
354
logger .info (
355
355
f"3D result is saved to { filename } _0.png"
356
- f" ~ { filename } _{ num_timestamp - 1 } .png"
356
+ f" ~ { filename } _{ num_timestamps - 1 } .png"
357
357
)
358
358
359
359
360
360
def save_plot_from_3d_dict (
361
361
filename : str ,
362
362
data_dict : Dict [str , Union [np .ndarray , paddle .Tensor ]],
363
363
visu_keys : Tuple [str , ...],
364
- num_timestamp : int = 1 ,
364
+ num_timestamps : int = 1 ,
365
365
):
366
366
"""Plot dict data as file.
367
367
368
368
Args:
369
369
filename (str): Output filename.
370
370
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
371
371
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
372
- num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
372
+ num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
373
373
"""
374
374
375
375
visu_data = [data_dict [k ] for k in visu_keys ]
376
376
if isinstance (visu_data [0 ], paddle .Tensor ):
377
377
visu_data = [x .numpy () for x in visu_data ]
378
378
379
- _save_plot_from_3d_array (filename , visu_data , visu_keys , num_timestamp )
379
+ _save_plot_from_3d_array (filename , visu_data , visu_keys , num_timestamps )
380
380
381
381
382
382
def _save_plot_weather_from_array (
0 commit comments