
Commit 8d6c346

Sgrasse/develop/issue 185 (ecmwf#221)
* disable prefix parsing for the CLI
* remove the run_id flag for fresh training
* separate the CLI definition into its own module
* enable multiple config overwrites from different sources
* add a method to parse a CLI arglist into a config
* pack extra overwrites for forecasting into a dict and supply it as an overwrite
* add a `--config` arg to train_continue to harmonize the CLI interfaces
* pass the overwrite constructed from additional CLI args to load_config
* make print_cf testable, filter secrets
* change: raise ValueError for an unknown type argument
* add tests for the configuration
* test config.load_streams, config.save
* improve logging/error handling of load_streams
* refactor: centralize naming of run config files
* align the semantics and default value of `--epoch`
* unify `--eval_run_id` and `--run_id_new`:
  - `--eval_run_id` (for evaluation), `--run_id_new` (for continue), and `--run_id` (for train) are coalesced into `--run_id`, which now always indicates the resulting `run_id` of a run
  - the previous `--run_id` flag, used to specify a model to load, is renamed to `--from_run_id`
* factor out shared CLI argument definitions
* remove evaluation args that can be overwritten via the new mechanism
* move the limiting of workers by samples into the evaluation method
* make `--analysis_streams_output` and `--config` accept multiple items
* clean up evaluate_from_args
* change how direct option overwrites are handled: use the `--options` argument
* move date formatting to the CLI
* make flags store_true and ensure nargs='+' always returns a list
* improve type hints and CLI help text
1 parent bfc669d commit 8d6c346
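
For orientation, the harmonized flags can be exercised directly through evaluate_from_args, mirroring the integration test further down. This is a hypothetical invocation: the run ids, the config path, and the key=value form of --options are assumptions; the flag names themselves appear in the diffs below.

from weathergen import evaluate_from_args

# Hypothetical example values; the --options key=value syntax is an assumption.
evaluate_from_args(
    [
        "--from_run_id", "abcd1234",      # pretrained model to load (was --run_id)
        "--run_id", "my_eval_run",        # run id under which results are stored (was --eval_run_id)
        "--config", "small1.yaml",        # one or more experiment config files
        "-start", "2022-10-10",
        "-end", "2022-10-11",
        "--samples", "10",
        "--epoch", "0",
        "--options", "forecast_steps=3",  # direct config overwrites
    ]
)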

File tree

7 files changed: +646 −223 lines


integration_tests/small1.py

Lines changed: 2 additions & 2 deletions
@@ -57,9 +57,9 @@ def test_train(setup, test_run_id):
     evaluate_from_args(
         "-start 2022-10-10 -end 2022-10-11 --samples 10 --epoch 0".split()
         + [
-            "--run_id",
+            "--from_run_id",
             test_run_id,
-            "--eval_run_id",
+            "--run_id",
             test_run_id,
             "--config",
             f"{weathergen_home}/integration_tests/small1.yaml",

src/weathergen/__init__.py

Lines changed: 63 additions & 185 deletions
@@ -7,15 +7,12 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

-import argparse
 import pdb
 import sys
 import time
 import traceback
-from pathlib import Path
-
-import pandas as pd

+import weathergen.utils.cli as cli
 import weathergen.utils.config as config
 from weathergen.train.trainer import Trainer
 from weathergen.utils.logger import init_loggers
@@ -33,189 +30,90 @@ def evaluate_from_args(argl: list[str]):

     When running integration tests, the arguments are directly provided.
     """
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--run_id",
-        type=str,
-        required=True,
-        help="Run/model id of pretrained WeatherGenerator model.",
-    )
-    parser.add_argument(
-        "--start_date",
-        "-start",
-        type=str,
-        required=False,
-        default="2022-10-01",
-        help="Start date for evaluation. Format must be parsable with pd.to_datetime.",
-    )
-    parser.add_argument(
-        "--end_date",
-        "-end",
-        type=str,
-        required=False,
-        default="2022-12-01",
-        help="End date for evaluation. Format must be parsable with pd.to_datetime.",
-    )
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=None,
-        help="Epoch of pretrained WeatherGenerator model used for evaluation (Default None corresponds to the last checkpoint).",
-    )
-    parser.add_argument(
-        "--forecast_steps",
-        type=int,
-        default=None,
-        help="Number of forecast steps for evaluation. Uses attribute from config when None is set.",
-    )
-    parser.add_argument(
-        "--samples", type=int, default=10000000, help="Number of evaluation samples."
-    )
-    parser.add_argument(
-        "--shuffle", type=bool, default=False, help="Shuffle samples from evaluation."
-    )
-    parser.add_argument(
-        "--save_samples", type=bool, default=True, help="Save samples from evaluation."
-    )
-    parser.add_argument(
-        "--analysis_streams_output",
-        type=list,
-        default=["ERA5"],
-        help="Analysis output streams during evaluation.",
-    )
-    parser.add_argument(
-        "--private_config",
-        type=Path,
-        default=None,
-        help="Path to private configuration file for paths.",
-    )
-    parser.add_argument(
-        "--eval_run_id",
-        type=str,
-        required=False,
-        dest="eval_run_id",
-        help="(optional) if specified, uses the provided run id to store the evaluation results",
-    )
-    parser.add_argument(
-        "--config",
-        type=Path,
-        default=None,
-        help="Optional experiment specfic configuration file",
-    )
-
+    parser = cli.get_evaluate_parser()
     args = parser.parse_args(argl)

     # TODO: move somewhere else
     init_loggers()

-    cf = config.load_config(args.private_config, args.run_id, args.epoch, args.config)
-
-    cf.run_history += [(cf.run_id, cf.istep)]
-
-    cf.samples_per_validation = args.samples
-    cf.log_validation = args.samples if args.save_samples else 0
-    start_date, end_date = pd.to_datetime(args.start_date), pd.to_datetime(args.end_date)
-
-    cf.start_date_val = start_date.strftime("%Y%m%d%H%M")
-    cf.end_date_val = end_date.strftime("%Y%m%d%H%M")
-
-    cf.shuffle = args.shuffle
-
-    cf.forecast_steps = args.forecast_steps if args.forecast_steps else cf.forecast_steps
-    # cf.forecast_policy = 'fixed'
+    evaluate_overwrite = dict(
+        shuffle=False,
+        start_date_val=args.start_date,
+        end_date_val=args.end_date,
+        samples_per_validation=args.samples,
+        log_validation=args.samples if args.save_samples else 0,
+        analysis_streams_output=args.analysis_streams_output,
+    )

-    # cf.analysis_streams_output = ['Surface', 'Air', 'METEOSAT', 'ATMS', 'IASI', 'AMSR2']
-    cf.analysis_streams_output = args.analysis_streams_output
+    cli_overwrite = config.from_cli_arglist(args.options)
+    cf = config.load_config(
+        args.private_config,
+        args.from_run_id,
+        args.epoch,
+        *args.config,
+        evaluate_overwrite,
+        cli_overwrite,
+    )

-    # make sure number of loaders does not exceed requested samples
-    cf.loader_num_workers = min(cf.loader_num_workers, args.samples)
+    cf.run_history += [(args.from_run_id, cf.istep)]

     trainer = Trainer()
-    trainer.evaluate(cf, args.run_id, args.epoch, run_id_new=args.eval_run_id)
+    trainer.evaluate(cf, args.from_run_id, args.epoch, run_id_new=args.run_id)


 ####################################################################################################
 def train_continue() -> None:
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "-id",
-        "--run_id",
-        type=str,
-        required=True,
-        help="run id of to be continued",
-    )
-    parser.add_argument(
-        "-e",
-        "--epoch",
-        type=int,
-        required=False,
-        default=-1,
-        help="epoch where to continue run",
-    )
-    parser.add_argument(
-        "-n",
-        "--run_id_new",
-        type=bool,
-        required=False,
-        default=False,
-        help="create new run id for cont'd run",
-    )
-    parser.add_argument(
-        "--private_config",
-        type=Path,
-        default=None,
-        help="Path to private configuration file for paths.",
-    )
-    parser.add_argument(
-        "--finetune_forecast",
-        action="store_true",
-        help="Fine tune for forecasting. It overwrites some of the Config settings.",
-    )
-
+    parser = cli.get_continue_parser()
     args = parser.parse_args()

-    cf = config.load_config(args.private_config, args.run_id, args.epoch, None)
+    if args.finetune_forecast:
+        finetune_overwrite = dict(
+            training_mode="forecast",
+            forecast_delta_hrs=0,  # 12
+            forecast_steps=1,  # [j for j in range(1,9) for i in range(4)]
+            forecast_policy="fixed",  # 'sequential_random' # 'fixed' #'sequential' #_random'
+            forecast_freeze_model=True,
+            forecast_att_dense_rate=1.0,  # 0.25
+            fe_num_blocks=8,
+            fe_num_heads=16,
+            fe_dropout_rate=0.1,
+            fe_with_qk_lnorm=True,
+            lr_start=0.000001,
+            lr_max=0.00003,
+            lr_final_decay=0.00003,
+            lr_final=0.0,
+            lr_steps_warmup=1024,
+            lr_steps_cooldown=4096,
+            lr_policy_warmup="cosine",
+            lr_policy_decay="linear",
+            lr_policy_cooldown="linear",
+            num_epochs=12,  # len(cf.forecast_steps) + 4
+            istep=0,
+        )
+    else:
+        finetune_overwrite = dict()
+
+    cli_overwrite = config.from_cli_arglist(args.options)
+    cf = config.load_config(
+        args.private_config,
+        args.from_run_id,
+        args.epoch,
+        finetune_overwrite,
+        *args.config,
+        cli_overwrite,
+    )

     # track history of run to ensure traceability of results
-    cf.run_history += [(cf.run_id, cf.istep)]
+    cf.run_history += [(args.from_run_id, cf.istep)]

-    #########################
     if args.finetune_forecast:
-        cf.training_mode = "forecast"
-        cf.forecast_delta_hrs = 0  # 12
-        cf.forecast_steps = 1  # [j for j in range(1,9) for i in range(4)]
-        cf.forecast_policy = "fixed"  # 'sequential_random' # 'fixed' #'sequential' #_random'
-        cf.forecast_freeze_model = True
-        cf.forecast_att_dense_rate = 1.0  # 0.25
-
         if cf.forecast_freeze_model:
             cf.with_fsdp = False
             import torch

             torch._dynamo.config.optimize_ddp = False
-
-        cf.fe_num_blocks = 8
-        cf.fe_num_heads = 16
-        cf.fe_dropout_rate = 0.1
-        cf.fe_with_qk_lnorm = True
-
-        cf.lr_start = 0.000001
-        cf.lr_max = 0.00003
-        cf.lr_final_decay = 0.00003
-        cf.lr_final = 0.0
-        cf.lr_steps_warmup = 1024
-        cf.lr_steps_cooldown = 4096
-        cf.lr_policy_warmup = "cosine"
-        cf.lr_policy_decay = "linear"
-        cf.lr_policy_cooldown = "linear"
-
-        cf.num_epochs = 12  # len(cf.forecast_steps) + 4
-        cf.istep = 0
-
     trainer = Trainer()
-    trainer.run(cf, args.run_id, args.epoch, args.run_id_new)
+    trainer.run(cf, args.from_run_id, args.epoch, run_id_new=args.run_id)


 ####################################################################################################
@@ -236,34 +134,14 @@ def train() -> None:
 def train_with_args(argl: list[str], stream_dir: str | None):
     """
     Training function for WeatherGenerator model."""
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--run_id",
-        type=str,
-        default=None,
-        help="Run id",
-    )
-    parser.add_argument(
-        "--private_config",
-        type=Path,
-        default=None,
-        help="Path to private configuration file for paths",
-    )
-    parser.add_argument(
-        "--config",
-        type=Path,
-        default=None,
-        help="Optional experiment specfic configuration file",
-    )
-
+    parser = cli.get_train_parser()
     args = parser.parse_args(argl)

     # TODO: move somewhere else
     init_loggers()

-    cf = config.load_config(args.private_config, None, None, args.config)
-    cf.run_id = args.run_id
+    cli_overwrite = config.from_cli_arglist(args.options)
+    cf = config.load_config(args.private_config, None, None, *args.config, cli_overwrite)

     if cf.with_flash_attention:
         assert cf.with_mixed_precision
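
The implementations of cli.get_*_parser, config.from_cli_arglist, and the variadic config.load_config live in files not shown on this page. The following is a minimal sketch of the overwrite-merging behaviour the calls above rely on, assuming a key=value syntax for --options and that later overwrites win over earlier ones; for brevity it only merges dict overwrites, not the YAML config files the real load_config also accepts, and the run id is made up.

import ast
from pathlib import Path
from typing import Any


def from_cli_arglist(arglist: list[str] | None) -> dict[str, Any]:
    """Sketch: parse overwrites such as ['lr_max=0.0001', 'shuffle=True'] into a dict."""
    overwrite: dict[str, Any] = {}
    for item in arglist or []:
        key, _, raw = item.partition("=")
        try:
            overwrite[key] = ast.literal_eval(raw)  # numbers, bools, lists, ...
        except (ValueError, SyntaxError):
            overwrite[key] = raw  # fall back to a plain string
    return overwrite


def load_config(
    private_config: Path | None,
    from_run_id: str | None,
    epoch: int | None,
    *overwrites: dict[str, Any],
) -> dict[str, Any]:
    """Sketch: start from the stored run config and apply overwrites left to right."""
    cf: dict[str, Any] = {}  # stand-in for the config restored from from_run_id/epoch
    for ow in overwrites:
        cf.update(ow)
    return cf


# The last overwrite wins: --options beats the finetune defaults.
cf = load_config(None, "abcd1234", -1,
                 {"lr_max": 0.00003, "num_epochs": 12},
                 from_cli_arglist(["num_epochs=4"]))
assert cf == {"lr_max": 0.00003, "num_epochs": 4}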

src/weathergen/train/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -100,11 +100,13 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
             shuffle=cf.shuffle,
         )

+        # make sure number of loaders does not exceed requested samples
+        loader_num_workers = min(cf.samples_per_validation, cf.loader_num_workers)
         loader_params = {
             "batch_size": None,
             "batch_sampler": None,
             "shuffle": False,
-            "num_workers": cf.loader_num_workers,
+            "num_workers": loader_num_workers,
             "pin_memory": True,
         }
         self.data_loader_validation = torch.utils.data.DataLoader(
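
The worker clamp that previously lived in evaluate_from_args now sits next to the loader construction. A quick illustration of its effect, with made-up numbers:

# With 10 validation samples requested but 16 workers configured,
# the evaluation DataLoader is created with only 10 workers.
samples_per_validation, configured_workers = 10, 16
loader_num_workers = min(samples_per_validation, configured_workers)
assert loader_num_workers == 10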
