77# granted to it by virtue of its status as an intergovernmental organisation
88# nor does it submit to any jurisdiction.
99
10- import argparse
1110import pdb
1211import sys
1312import time
1413import traceback
15- from pathlib import Path
16-
17- import pandas as pd
1814
15+ import weathergen .utils .cli as cli
1916import weathergen .utils .config as config
2017from weathergen .train .trainer import Trainer
2118from weathergen .utils .logger import init_loggers
@@ -33,189 +30,90 @@ def evaluate_from_args(argl: list[str]):
3330
3431 When running integration tests, the arguments are directly provided.
3532 """
36- parser = argparse .ArgumentParser ()
37-
38- parser .add_argument (
39- "--run_id" ,
40- type = str ,
41- required = True ,
42- help = "Run/model id of pretrained WeatherGenerator model." ,
43- )
44- parser .add_argument (
45- "--start_date" ,
46- "-start" ,
47- type = str ,
48- required = False ,
49- default = "2022-10-01" ,
50- help = "Start date for evaluation. Format must be parsable with pd.to_datetime." ,
51- )
52- parser .add_argument (
53- "--end_date" ,
54- "-end" ,
55- type = str ,
56- required = False ,
57- default = "2022-12-01" ,
58- help = "End date for evaluation. Format must be parsable with pd.to_datetime." ,
59- )
60- parser .add_argument (
61- "--epoch" ,
62- type = int ,
63- default = None ,
64- help = "Epoch of pretrained WeatherGenerator model used for evaluation (Default None corresponds to the last checkpoint)." ,
65- )
66- parser .add_argument (
67- "--forecast_steps" ,
68- type = int ,
69- default = None ,
70- help = "Number of forecast steps for evaluation. Uses attribute from config when None is set." ,
71- )
72- parser .add_argument (
73- "--samples" , type = int , default = 10000000 , help = "Number of evaluation samples."
74- )
75- parser .add_argument (
76- "--shuffle" , type = bool , default = False , help = "Shuffle samples from evaluation."
77- )
78- parser .add_argument (
79- "--save_samples" , type = bool , default = True , help = "Save samples from evaluation."
80- )
81- parser .add_argument (
82- "--analysis_streams_output" ,
83- type = list ,
84- default = ["ERA5" ],
85- help = "Analysis output streams during evaluation." ,
86- )
87- parser .add_argument (
88- "--private_config" ,
89- type = Path ,
90- default = None ,
91- help = "Path to private configuration file for paths." ,
92- )
93- parser .add_argument (
94- "--eval_run_id" ,
95- type = str ,
96- required = False ,
97- dest = "eval_run_id" ,
98- help = "(optional) if specified, uses the provided run id to store the evaluation results" ,
99- )
100- parser .add_argument (
101- "--config" ,
102- type = Path ,
103- default = None ,
104- help = "Optional experiment specfic configuration file" ,
105- )
106-
33+ parser = cli .get_evaluate_parser ()
10734 args = parser .parse_args (argl )
10835
10936 # TODO: move somewhere else
11037 init_loggers ()
11138
112- cf = config .load_config (args .private_config , args .run_id , args .epoch , args .config )
113-
114- cf .run_history += [(cf .run_id , cf .istep )]
115-
116- cf .samples_per_validation = args .samples
117- cf .log_validation = args .samples if args .save_samples else 0
118- start_date , end_date = pd .to_datetime (args .start_date ), pd .to_datetime (args .end_date )
119-
120- cf .start_date_val = start_date .strftime ("%Y%m%d%H%M" )
121- cf .end_date_val = end_date .strftime ("%Y%m%d%H%M" )
122-
123- cf .shuffle = args .shuffle
124-
125- cf .forecast_steps = args .forecast_steps if args .forecast_steps else cf .forecast_steps
126- # cf.forecast_policy = 'fixed'
39+ evaluate_overwrite = dict (
40+ shuffle = False ,
41+ start_date_val = args .start_date ,
42+ end_date_val = args .end_date ,
43+ samples_per_validation = args .samples ,
44+ log_validation = args .samples if args .save_samples else 0 ,
45+ analysis_streams_output = args .analysis_streams_output ,
46+ )
12747
128- # cf.analysis_streams_output = ['Surface', 'Air', 'METEOSAT', 'ATMS', 'IASI', 'AMSR2']
129- cf .analysis_streams_output = args .analysis_streams_output
48+ cli_overwrite = config .from_cli_arglist (args .options )
49+ cf = config .load_config (
50+ args .private_config ,
51+ args .from_run_id ,
52+ args .epoch ,
53+ * args .config ,
54+ evaluate_overwrite ,
55+ cli_overwrite ,
56+ )
13057
131- # make sure number of loaders does not exceed requested samples
132- cf .loader_num_workers = min (cf .loader_num_workers , args .samples )
58+ cf .run_history += [(args .from_run_id , cf .istep )]
13359
13460 trainer = Trainer ()
135- trainer .evaluate (cf , args .run_id , args .epoch , run_id_new = args .eval_run_id )
61+ trainer .evaluate (cf , args .from_run_id , args .epoch , run_id_new = args .run_id )
13662
13763
13864####################################################################################################
def train_continue() -> None:
    """Continue training from an existing run, optionally fine-tuning for forecasting.

    Arguments are read from the process command line using the parser
    returned by ``cli.get_continue_parser()``. The run identified by
    ``args.from_run_id`` (at ``args.epoch``) is loaded, configuration
    overwrites are applied, and training is resumed via ``Trainer.run``.
    """
    parser = cli.get_continue_parser()
    args = parser.parse_args()

    if args.finetune_forecast:
        # Preset overwrites applied when fine-tuning for forecasting.
        finetune_overwrite = {
            "training_mode": "forecast",
            "forecast_delta_hrs": 0,
            "forecast_steps": 1,
            "forecast_policy": "fixed",
            "forecast_freeze_model": True,
            "forecast_att_dense_rate": 1.0,
            "fe_num_blocks": 8,
            "fe_num_heads": 16,
            "fe_dropout_rate": 0.1,
            "fe_with_qk_lnorm": True,
            "lr_start": 0.000001,
            "lr_max": 0.00003,
            "lr_final_decay": 0.00003,
            "lr_final": 0.0,
            "lr_steps_warmup": 1024,
            "lr_steps_cooldown": 4096,
            "lr_policy_warmup": "cosine",
            "lr_policy_decay": "linear",
            "lr_policy_cooldown": "linear",
            "num_epochs": 12,
            "istep": 0,
        }
    else:
        finetune_overwrite = {}

    cli_overwrite = config.from_cli_arglist(args.options)
    # NOTE(review): the fine-tune preset is passed before the user config
    # files and the CLI options; presumably later arguments take precedence
    # in load_config -- confirm against its implementation.
    cf = config.load_config(
        args.private_config,
        args.from_run_id,
        args.epoch,
        finetune_overwrite,
        *args.config,
        cli_overwrite,
    )

    # track history of run to ensure traceability of results
    cf.run_history += [(args.from_run_id, cf.istep)]

    if args.finetune_forecast and cf.forecast_freeze_model:
        # With a frozen model, FSDP and the dynamo DDP optimisation are
        # switched off before the trainer is created.
        cf.with_fsdp = False
        import torch

        torch._dynamo.config.optimize_ddp = False

    trainer = Trainer()
    trainer.run(cf, args.from_run_id, args.epoch, run_id_new=args.run_id)
219117
220118
221119####################################################################################################
@@ -236,34 +134,14 @@ def train() -> None:
236134def train_with_args (argl : list [str ], stream_dir : str | None ):
237135 """
238136 Training function for WeatherGenerator model."""
239- parser = argparse .ArgumentParser ()
240-
241- parser .add_argument (
242- "--run_id" ,
243- type = str ,
244- default = None ,
245- help = "Run id" ,
246- )
247- parser .add_argument (
248- "--private_config" ,
249- type = Path ,
250- default = None ,
251- help = "Path to private configuration file for paths" ,
252- )
253- parser .add_argument (
254- "--config" ,
255- type = Path ,
256- default = None ,
257- help = "Optional experiment specfic configuration file" ,
258- )
259-
137+ parser = cli .get_train_parser ()
260138 args = parser .parse_args (argl )
261139
262140 # TODO: move somewhere else
263141 init_loggers ()
264142
265- cf = config .load_config (args .private_config , None , None , args . config )
266- cf . run_id = args .run_id
143+ cli_overwrite = config .from_cli_arglist (args .options )
144+ cf = config . load_config ( args .private_config , None , None , * args . config , cli_overwrite )
267145
268146 if cf .with_flash_attention :
269147 assert cf .with_mixed_precision
0 commit comments